diff --git a/examples/summarization/bart/run_bart_sum.py b/examples/summarization/bart/run_bart_sum.py
index 31836ce477..b580851da5 100644
--- a/examples/summarization/bart/run_bart_sum.py
+++ b/examples/summarization/bart/run_bart_sum.py
@@ -19,30 +19,20 @@ class BartSystem(BaseTransformer):
     mode = "language-modeling"
 
     def __init__(self, hparams):
-        super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)
+        super().__init__(hparams, num_labels=None, mode=self.mode, output_past=False)
 
-    def forward(
-        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
-    ):
+    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
         return self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            lm_labels=lm_labels,
+            input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
         )
 
     def _step(self, batch):
-        y = batch["target_ids"]
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
         y_ids = y[:, :-1].contiguous()
         lm_labels = y[:, 1:].clone()
-        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
-        outputs = self(
-            input_ids=batch["source_ids"],
-            attention_mask=batch["source_mask"],
-            decoder_input_ids=y_ids,
-            lm_labels=lm_labels,
-        )
+        lm_labels[y[:, 1:] == pad_token_id] = -100
+        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
 
         loss = outputs[0]
 
@@ -64,9 +54,13 @@ class BartSystem(BaseTransformer):
         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
 
     def test_step(self, batch, batch_idx):
+        # NOTE: this generation will not use the cache.
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
+        # NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
         generated_ids = self.model.generate(
-            batch["source_ids"],
-            attention_mask=batch["source_mask"],
+            source_ids,
+            source_mask,
             num_beams=1,
             max_length=80,
             repetition_penalty=2.5,
@@ -77,10 +71,7 @@ class BartSystem(BaseTransformer):
             self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             for g in generated_ids
         ]
-        target = [
-            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            for t in batch["target_ids"]
-        ]
+        target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
         loss = self._step(batch)
 
         return {"val_loss": loss, "preds": preds, "target": target}
@@ -101,11 +92,21 @@ class BartSystem(BaseTransformer):
         return self.test_end(outputs)
 
-    def train_dataloader(self):
-        train_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
+    @property
+    def dataset_kwargs(self):
+        return dict(
+            data_dir=self.hparams.data_dir,
+            max_source_length=self.hparams.max_source_length,
+            max_target_length=self.hparams.max_target_length,
         )
-        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
+
+    def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
+        dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
+        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
+        return dataloader
+
+    def train_dataloader(self) -> DataLoader:
+        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
 
         t_total = (
             (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
             // self.hparams.gradient_accumulation_steps
@@ -117,29 +118,30 @@ class BartSystem(BaseTransformer):
         self.lr_scheduler = scheduler
         return dataloader
 
-    def val_dataloader(self):
-        val_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
-        )
-        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)
+    def val_dataloader(self) -> DataLoader:
+        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
 
-    def test_dataloader(self):
-        test_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
-        )
-        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)
+    def test_dataloader(self) -> DataLoader:
+        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
 
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         BaseTransformer.add_model_specific_args(parser, root_dir)
         # Add BART specific options
         parser.add_argument(
-            "--max_seq_length",
+            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
+        parser.add_argument(
+            "--max_target_length",
+            default=56,
+            type=int,
+            help="The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded.",
+        )
 
         parser.add_argument(
             "--data_dir",
@@ -158,7 +160,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # If output_dir not provided, a folder will be generated in pwd
-    if args.output_dir is None:
+    if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)
diff --git a/examples/summarization/bart/test_bart_examples.py b/examples/summarization/bart/test_bart_examples.py
index 40be3b5668..b136edfdb7 100644
--- a/examples/summarization/bart/test_bart_examples.py
+++ b/examples/summarization/bart/test_bart_examples.py
@@ -5,28 +5,57 @@ import unittest
 from pathlib import Path
 from unittest.mock import patch
 
+from torch.utils.data import DataLoader
+
+from transformers import BartTokenizer
+
 from .evaluate_cnn import run_generate
+from .utils import SummarizationDataset
 
-articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 
 
+def _dump_articles(path: Path, articles: list):
+    with path.open("w") as f:
+        f.write("\n".join(articles))
+
+
 class TestBartExamples(unittest.TestCase):
     def test_bart_cnn_cli(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
         tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-        with tmp.open("w") as f:
-            f.write("\n".join(articles))
-
         output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
-
+        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
+        _dump_articles(tmp, articles)
         testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
-
         with patch.object(sys, "argv", testargs):
             run_generate()
-            self.assertTrue(Path(output_file_name).exists())
+            self.assertTrue(output_file_name.exists())
+
+    def test_bart_summarization_dataset(self):
+        tmp_dir = Path(tempfile.gettempdir())
+        articles = [" Sam ate lunch today", "Sams lunch ingredients"]
+        summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
+        _dump_articles((tmp_dir / "train.source"), articles)
+        _dump_articles((tmp_dir / "train.target"), summaries)
+        tokenizer = BartTokenizer.from_pretrained("bart-large")
+        max_len_source = max(len(tokenizer.encode(a)) for a in articles)
+        max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
+        trunc_target = 4
+        train_dataset = SummarizationDataset(
+            tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
+            # show that articles were trimmed.
+ self.assertEqual(batch["source_ids"].shape[1], max_len_source) + self.assertGreater(20, batch["source_ids"].shape[1]) # trimmed significantly + + # show that targets were truncated + self.assertEqual(batch["target_ids"].shape[1], trunc_target) # Truncated + self.assertGreater(max_len_target, trunc_target) # Truncated diff --git a/examples/summarization/bart/utils.py b/examples/summarization/bart/utils.py index fbe3c2d4e1..b3d9d0e84b 100644 --- a/examples/summarization/bart/utils.py +++ b/examples/summarization/bart/utils.py @@ -1,35 +1,35 @@ import os +import torch from torch.utils.data import Dataset +from transformers.tokenization_utils import trim_batch + + +def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"): + examples = [] + with open(data_path, "r") as f: + for text in f.readlines(): + tokenized = tokenizer.batch_encode_plus( + [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, + ) + examples.append(tokenized) + return examples + class SummarizationDataset(Dataset): - def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024): - super(SummarizationDataset,).__init__() + def __init__( + self, + tokenizer, + data_dir="./cnn-dailymail/cnn_dm/", + type_path="train", + max_source_length=1024, + max_target_length=56, + ): + super().__init__() self.tokenizer = tokenizer - - self.source = [] - self.target = [] - - print("loading " + type_path + " source.") - - with open(os.path.join(data_dir, type_path + ".source"), "r") as f: - for text in f.readlines(): # each text is a line and a full story - tokenized = tokenizer.batch_encode_plus( - [text], max_length=block_size, pad_to_max_length=True, return_tensors="pt" - ) - self.source.append(tokenized) - f.close() - - print("loading " + type_path + " target.") - - with open(os.path.join(data_dir, type_path + ".target"), "r") as f: - for text in f.readlines(): # each text is a line and a summary - tokenized = tokenizer.batch_encode_plus( - [text], max_length=56, pad_to_max_length=True, return_tensors="pt" - ) - self.target.append(tokenized) - f.close() + self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) + self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) def __len__(self): return len(self.source) @@ -37,7 +37,20 @@ class SummarizationDataset(Dataset): def __getitem__(self, index): source_ids = self.source[index]["input_ids"].squeeze() target_ids = self.target[index]["input_ids"].squeeze() - - src_mask = self.source[index]["attention_mask"].squeeze() # might need to squeeze - + src_mask = self.source[index]["attention_mask"].squeeze() return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids} + + @staticmethod + def trim_seq2seq_batch(batch, pad_token_id): + y = trim_batch(batch["target_ids"], pad_token_id) + source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) + return source_ids, source_mask, y + + def collate_fn(self, batch): + input_ids = torch.stack([x["source_ids"] for x in batch]) + masks = torch.stack([x["source_mask"] for x in batch]) + target_ids = torch.stack([x["target_ids"] for x in batch]) + pad_token_id = self.tokenizer.pad_token_id + y = trim_batch(target_ids, pad_token_id) + source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) + return {"source_ids": source_ids, "source_mask": 
source_mask, "target_ids": y} diff --git a/examples/transformer_base.py b/examples/transformer_base.py index 739c5a3e40..a744c91594 100644 --- a/examples/transformer_base.py +++ b/examples/transformer_base.py @@ -47,27 +47,29 @@ def set_seed(args): class BaseTransformer(pl.LightningModule): - def __init__(self, hparams, num_labels=None, mode="base"): + def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs): "Initialize a model." super(BaseTransformer, self).__init__() self.hparams = hparams + cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None self.hparams.model_type = self.hparams.model_type.lower() config = AutoConfig.from_pretrained( self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, **({"num_labels": num_labels} if num_labels is not None else {}), - cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, + cache_dir=cache_dir, + **config_kwargs, ) tokenizer = AutoTokenizer.from_pretrained( self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, do_lower_case=self.hparams.do_lower_case, - cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, + cache_dir=cache_dir, ) model = MODEL_MODES[mode].from_pretrained( self.hparams.model_name_or_path, from_tf=bool(".ckpt" in self.hparams.model_name_or_path), config=config, - cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None, + cache_dir=cache_dir, ) self.config, self.tokenizer, self.model = config, tokenizer, model