[cleanup] remove redundant code in SummarizationDataset (#5119)
This commit is contained in:
parent
5f721ad6e4
commit
2db1e2f415
|
@ -13,8 +13,6 @@ from torch import nn
|
|||
from torch.utils.data import Dataset, Sampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartTokenizer
|
||||
|
||||
|
||||
def encode_file(
|
||||
tokenizer,
|
||||
|
@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
|
|||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
|
||||
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
|
||||
self.source = encode_file(
|
||||
tokenizer,
|
||||
os.path.join(data_dir, type_path + ".source"),
|
||||
|
@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
|
|||
prefix=prefix,
|
||||
tok_name=tok_name,
|
||||
)
|
||||
if type_path == "train":
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
else:
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
|
||||
tgt_path = os.path.join(data_dir, type_path + ".target")
|
||||
self.target = encode_file(
|
||||
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
|
||||
)
|
||||
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
|
||||
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
|
||||
if n_obs is not None:
|
||||
self.source = self.source[:n_obs]
|
||||
self.target = self.target[:n_obs]
|
||||
|
|
Loading…
Reference in New Issue