Seq2SeqDataset: avoid passing src_lang everywhere (#7470)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 08939cfdf7
commit c031d01023
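Summary: rather than storing src_lang, tgt_lang, and add_prefix_space as dedicated dataset attributes and naming each of them at every tokenizer call, AbstractSeq2SeqDataset now accepts **dataset_kwargs, folds the BART-specific add_prefix_space flag into that dict once at construction time, and splats the dict into prepare_seq2seq_batch and encode_line. A parametrized test asserts which keys land in dataset_kwargs for each tokenizer family.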
@@ -185,3 +185,36 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
     ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
     ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
     assert ids1.intersection(ids2) == set()
+
+
+@pytest.mark.parametrize(
+    "tok_name",
+    [
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
+    ],
+)
+def test_dataset_kwargs(tok_name):
+    tokenizer = AutoTokenizer.from_pretrained(tok_name)
+    if tok_name == MBART_TINY:
+        train_dataset = Seq2SeqDataset(
+            tokenizer,
+            data_dir=make_test_data_dir(),
+            type_path="train",
+            max_source_length=4,
+            max_target_length=8,
+            src_lang="EN",
+            tgt_lang="FR",
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "src_lang" in kwargs and "tgt_lang" in kwargs
+    else:
+        train_dataset = Seq2SeqDataset(
+            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+        )
+        kwargs = train_dataset.dataset_kwargs
+        assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
+        assert len(kwargs) == 1 if tok_name == BART_TINY else len(kwargs) == 0
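Usage note: the new test can be selected on its own with pytest's -k filter. The module path below is assumed from the examples/seq2seq layout of the repo at the time and may differ:

python -m pytest examples/seq2seq/test_seq2seq_examples.py -k test_dataset_kwargs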
@@ -52,19 +52,6 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
     return loss, nll_loss


-def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
-    """Only used by LegacyDataset"""
-    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
-    return tokenizer(
-        [line],
-        max_length=max_length,
-        padding="max_length" if pad_to_max_length else None,
-        truncation=True,
-        return_tensors=return_tensors,
-        **extra_kw,
-    )
-
-
 def lmap(f: Callable, x: Iterable) -> List:
     """list(map(f, x))"""
     return list(map(f, x))
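Design note: the removed helper re-ran the isinstance(tokenizer, BartTokenizer) check on every encoded line. After this commit that check runs once, in the dataset constructor, and its result travels through self.dataset_kwargs; the helper reappears further down as a LegacySeq2SeqDataset method so it can read that dict.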
@@ -97,9 +84,8 @@ class AbstractSeq2SeqDataset(Dataset):
         max_target_length,
         type_path="train",
         n_obs=None,
-        src_lang=None,
-        tgt_lang=None,
         prefix="",
+        **dataset_kwargs
     ):
         super().__init__()
         self.src_file = Path(data_dir).joinpath(type_path + ".source")
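Because the signature now ends in **dataset_kwargs, callers that used to pass src_lang=... and tgt_lang=... keep working unchanged; those names are simply collected into a dict instead of binding to dedicated parameters. A throwaway illustration of the collection (f is an illustrative stand-in, not part of the diff):

def f(prefix="", **dataset_kwargs):
    # Unbound keyword arguments are gathered into one dict.
    return dataset_kwargs

assert f(src_lang="EN", tgt_lang="FR") == {"src_lang": "EN", "tgt_lang": "FR"}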
@@ -120,9 +106,8 @@ class AbstractSeq2SeqDataset(Dataset):
         if n_obs is not None:
             self.src_lens = self.src_lens[:n_obs]
         self.pad_token_id = self.tokenizer.pad_token_id
-        self.src_lang = src_lang
-        self.tgt_lang = tgt_lang
-        self.add_prefix_space = isinstance(self.tokenizer, BartTokenizer)
+        self.dataset_kwargs = dataset_kwargs
+        dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})

     def __len__(self):
         return len(self.src_lens)
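One subtlety in the two added lines: self.dataset_kwargs is bound to the same dict object as the local dataset_kwargs, so the update() that runs after the assignment is still visible through the attribute. A minimal sketch of that aliasing (names illustrative):

dataset_kwargs = {"src_lang": "EN", "tgt_lang": "FR"}
stored = dataset_kwargs  # plays the role of self.dataset_kwargs
dataset_kwargs.update({"add_prefix_space": True})
assert stored == {"src_lang": "EN", "tgt_lang": "FR", "add_prefix_space": True}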
@@ -182,8 +167,8 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
         tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
         assert source_line, f"empty source line for index {index}"
         assert tgt_line, f"empty tgt line for index {index}"
-        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
-        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
+        source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
+        target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)

         source_ids = source_inputs["input_ids"].squeeze()
         target_ids = target_inputs["input_ids"].squeeze()
@@ -194,6 +179,17 @@ class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
             "labels": target_ids,
         }

+    def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
+        """Only used by LegacyDataset"""
+        return tokenizer(
+            [line],
+            max_length=max_length,
+            padding="max_length" if pad_to_max_length else None,
+            truncation=True,
+            return_tensors=return_tensors,
+            **self.dataset_kwargs,
+        )
+
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
         input_ids = torch.stack([x["input_ids"] for x in batch])
         masks = torch.stack([x["attention_mask"] for x in batch])
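Note: turning encode_line into a method keeps its signature identical for the two call sites above, while the per-call extra_kw computation is replaced by the constructor-built self.dataset_kwargs, which already carries add_prefix_space whenever the tokenizer is a BartTokenizer.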
@@ -224,13 +220,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         """Call prepare_seq2seq_batch."""
         batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
-            src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
             max_target_length=self.max_target_length,
             return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
+            **self.dataset_kwargs,
         ).data
         batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
         return batch_encoding
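Taken together with the constructor change, any keyword supplied when the dataset is built now reaches prepare_seq2seq_batch untouched. A minimal runnable sketch of that flow using stand-in types (DummyTokenizer and SketchDataset are illustrative, not the real classes; only the prepare_seq2seq_batch name mirrors the API this diff targets):

class DummyTokenizer:
    def prepare_seq2seq_batch(self, src_texts, **kwargs):
        # Record what the dataset forwarded instead of tokenizing.
        return {"src_texts": src_texts, "forwarded": kwargs}

class SketchDataset:
    def __init__(self, tokenizer, **dataset_kwargs):
        self.tokenizer = tokenizer
        self.dataset_kwargs = dataset_kwargs  # e.g. src_lang/tgt_lang for mBART

    def collate_fn(self, batch):
        # Single pass-through point, mirroring the hunk above.
        return self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            **self.dataset_kwargs,
        )

ds = SketchDataset(DummyTokenizer(), src_lang="en_XX", tgt_lang="ro_RO")
out = ds.collate_fn([{"src_texts": "hello"}])
assert out["forwarded"] == {"src_lang": "en_XX", "tgt_lang": "ro_RO"}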