[cleanup] remove redundant code in SummarizationDataset (#5119)

This commit is contained in:
Sam Shleifer 2020-06-18 20:34:48 -04:00 committed by GitHub
parent 5f721ad6e4
commit 2db1e2f415
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 2 additions and 10 deletions

View File

@ -13,8 +13,6 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
prefix="",
):
super().__init__()
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
prefix=prefix,
tok_name=tok_name,
)
if type_path == "train":
tgt_path = os.path.join(data_dir, type_path + ".target")
else:
tgt_path = os.path.join(data_dir, type_path + ".target")
tgt_path = os.path.join(data_dir, type_path + ".target")
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]