diff --git a/examples/seq2seq/test_datasets.py b/examples/seq2seq/test_datasets.py
index aaf94fa5e0..dbd6769ad2 100644
--- a/examples/seq2seq/test_datasets.py
+++ b/examples/seq2seq/test_datasets.py
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
     ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
     ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
     assert ids1.intersection(ids2) == set()
+
+
+@pytest.mark.parametrize(
+    "tok_name",
+    [
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
+    ],
+)
+def test_dataset_kwargs(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
+    if tok_name == MBART_TINY:
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=make_test_data_dir(),
+            type_path="train",
+            max_source_length=4,
+            max_target_length=8,
+            src_lang="EN",
+            tgt_lang="FR",
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "src_lang" in kwargs and "tgt_lang" in kwargs
+    else:
+        train_dataset = Seq2SeqDataset(
+            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index ac1629c0c5..43f5caf05f 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
     return loss, nll_loss
 
 
-def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
-    """Only used by LegacyDataset"""
-    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
-    return tokenizer(
-        [line],
-        max_length=max_length,
-        padding="max_length" if pad_to_max_length else None,
-        truncation=True,
-        return_tensors=return_tensors,
-        **extra_kw,
-    )
-
-
 def lmap(f: Callable, x: Iterable) -> List:
     """list(map(f, x))"""
     return list(map(f, x))
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
         max_target_length,
         type_path="train",
         n_obs=None,
-        src_lang=None,
-        tgt_lang=None,
         prefix="",
+        **dataset_kwargs
     ):
         super().__init__()
         self.src_file = Path(data_dir).joinpath(type_path + ".source")
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
         if n_obs is not None:
             self.src_lens = self.src_lens[:n_obs]
         self.pad_token_id = self.tokenizer.pad_token_id
-        self.src_lang = src_lang
-        self.tgt_lang = tgt_lang
-        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
+        self.dataset_kwargs = dataset_kwargs
+        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
 
     def __len__(self):
         return len(self.src_lens)
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
-        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
-        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
+        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
+        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
 
         source_ids = source_inputs["input_ids"].squeeze()
         target_ids = target_inputs["input_ids"].squeeze()
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
             "labels": target_ids,
         }
 
+    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
+        """Only used by LegacyDataset"""
+        return tokenizer(
+            [line],
+            max_length=max_length,
+            padding="max_length" if pad_to_max_length else None,
+            truncation=True,
+            return_tensors=return_tensors,
+            **self.dataset_kwargs,
+        )
+
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         """Call prepare_seq2seq_batch."""
         batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
-            src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
             max_target_length=self.max_target_length,
             return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
+            **self.dataset_kwargs,
         ).data
         batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
         return batch_encoding
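
Below is a minimal usage sketch (not part of the patch) of how extra keyword arguments now flow through the refactored `Seq2SeqDataset` into `prepare_seq2seq_batch` via `**dataset_kwargs`. The checkpoint name and `data_dir` are placeholders, and the import assumes the patched `examples/seq2seq/utils.py` is on the path.

```python
# Illustrative only: assumes the patched examples/seq2seq/utils.py is importable
# and that the placeholder checkpoint and data directory below exist locally.
from transformers import AutoTokenizer

from utils import Seq2SeqDataset

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")  # placeholder checkpoint
dataset = Seq2SeqDataset(
    tokenizer,
    data_dir="wmt_en_ro",     # placeholder dir containing train.source / train.target
    type_path="train",
    max_source_length=4,
    max_target_length=8,
    src_lang="en_XX",         # captured by **dataset_kwargs ...
    tgt_lang="ro_RO",
)
# ... stored on the dataset and forwarded verbatim to
# tokenizer.prepare_seq2seq_batch() inside collate_fn.
print(dataset.dataset_kwargs)  # {'src_lang': 'en_XX', 'tgt_lang': 'ro_RO'}
```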