Seq2SeqDataset: avoid passing src_lang everywhere (#7470)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Amanpreet Singh, 2020-09-30 13:27:48 -04:00, committed by GitHub
parent 08939cfdf7
commit c031d01023
2 changed files with 50 additions and 23 deletions
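
In short: rather than threading tokenizer-specific arguments (src_lang/tgt_lang for multilingual tokenizers, add_prefix_space for BART) through dedicated dataset parameters, the datasets now collect them once as **dataset_kwargs and forward them wherever the tokenizer is called. A minimal sketch of the pattern, using an illustrative stand-in class (KwargsForwardingDataset is hypothetical, not the real Seq2SeqDataset shown in the diff below):

    from transformers import BartTokenizer

    class KwargsForwardingDataset:
        def __init__(self, tokenizer, max_source_length, max_target_length, **dataset_kwargs):
            self.tokenizer = tokenizer
            self.max_source_length = max_source_length
            self.max_target_length = max_target_length
            # e.g. src_lang="en_XX", tgt_lang="ro_RO" when an mBART tokenizer is used
            self.dataset_kwargs = dataset_kwargs
            if isinstance(tokenizer, BartTokenizer):
                # BART needs add_prefix_space; inject it once here instead of
                # plumbing a dedicated flag through every call site.
                self.dataset_kwargs["add_prefix_space"] = True

        def collate_fn(self, src_texts, tgt_texts):
            return self.tokenizer.prepare_seq2seq_batch(
                src_texts,
                tgt_texts=tgt_texts,
                max_length=self.max_source_length,
                max_target_length=self.max_target_length,
                return_tensors="pt",
                **self.dataset_kwargs,  # forwarded verbatim to the tokenizer
            ).data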


@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
    ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
    ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
    assert ids1.intersection(ids2) == set()


@pytest.mark.parametrize(
    "tok_name",
    [
        MBART_TINY,
        MARIAN_TINY,
        T5_TINY,
        BART_TINY,
        PEGASUS_XSUM,
    ],
)
def test_dataset_kwargs(tok_name):
    tokenizer = AutoTokenizer.from_pretrained(tok_name)
    if tok_name == MBART_TINY:
        train_dataset = Seq2SeqDataset(
            tokenizer,
            data_dir=make_test_data_dir(),
            type_path="train",
            max_source_length=4,
            max_target_length=8,
            src_lang="EN",
            tgt_lang="FR",
        )
        kwargs = train_dataset.dataset_kwargs
        assert "src_lang" in kwargs and "tgt_lang" in kwargs
    else:
        train_dataset = Seq2SeqDataset(
            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
        )
        kwargs = train_dataset.dataset_kwargs
        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
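
A compact reading of the assertions above (expected_kwargs is a hypothetical summary, not part of the test; keys name the parametrized tokenizers):

    # What dataset_kwargs should hold right after construction, per tokenizer:
    expected_kwargs = {
        "mbart": {"src_lang": "EN", "tgt_lang": "FR"},  # passed-in languages survive
        "bart": {"add_prefix_space": True},             # auto-injected, hence len == 1
        "marian": {},                                   # nothing tokenizer-specific
        "t5": {},
        "pegasus": {},
    }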


@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    return loss, nll_loss


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Only used by LegacyDataset"""
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
        **dataset_kwargs
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
        self.dataset_kwargs = dataset_kwargs
        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

    def __len__(self):
        return len(self.src_lens)
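
One subtlety in the hunk above: self.dataset_kwargs and the local dataset_kwargs name the same dict, so the update() through the local name is visible via the attribute. A two-line demonstration of that aliasing (generic Python, not from the diff):

    kwargs = {"src_lang": "en_XX"}
    stored = kwargs                            # same object, as in self.dataset_kwargs = dataset_kwargs
    kwargs.update({"add_prefix_space": True})  # mirrors the BartTokenizer branch
    assert stored["add_prefix_space"] is True  # the attribute sees the update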
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
            "labels": target_ids,
        }

    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
        """Only used by LegacyDataset"""
        return tokenizer(
            [line],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **self.dataset_kwargs,
        )
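
This reinstates the helper deleted in the earlier hunk as an instance method, with the old extra_kw replaced by self.dataset_kwargs (empty, or {"add_prefix_space": True} when the tokenizer is a BartTokenizer). A quick illustrative call, assuming ds is any constructed LegacySeq2SeqDataset:

    batch = ds.encode_line(ds.tokenizer, "hello world", max_length=8)
    print(batch["input_ids"].shape)  # torch.Size([1, 8]) -- padded to max_length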
    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
        """Call prepare_seq2seq_batch."""
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
            **self.dataset_kwargs,
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding
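
End to end, the language arguments now reach prepare_seq2seq_batch without collate_fn ever naming them. A hedged usage sketch (the checkpoint is one plausible mBART model; the data directory is a placeholder and must contain train.source/train.target files):

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir="path/to/data",
        max_source_length=32,
        max_target_length=32,
        type_path="train",
        src_lang="en_XX",   # captured into dataset_kwargs at construction...
        tgt_lang="ro_RO",
    )
    batch = ds.collate_fn([ds[0], ds[1]])
    # ...and forwarded via **self.dataset_kwargs above, so input_ids and labels
    # carry the right language codes.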