[examples] SummarizationDataset cleanup (#3451)

Author: Sam Shleifer, 2020-04-07 19:05:58 -04:00 (committed by GitHub)
parent b0ad069517
commit e344e3d402
4 changed files with 125 additions and 79 deletions

File 1 of 4: the BartSystem fine-tuning script

@@ -19,30 +19,20 @@ class BartSystem(BaseTransformer):
     mode = "language-modeling"
 
     def __init__(self, hparams):
-        super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)
+        super().__init__(hparams, num_labels=None, mode=self.mode, output_past=False)
 
-    def forward(
-        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
-    ):
+    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
         return self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            decoder_input_ids=decoder_input_ids,
-            decoder_attention_mask=decoder_attention_mask,
-            lm_labels=lm_labels,
+            input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
         )
 
     def _step(self, batch):
-        y = batch["target_ids"]
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
         y_ids = y[:, :-1].contiguous()
         lm_labels = y[:, 1:].clone()
-        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
-        outputs = self(
-            input_ids=batch["source_ids"],
-            attention_mask=batch["source_mask"],
-            decoder_input_ids=y_ids,
-            lm_labels=lm_labels,
-        )
+        lm_labels[y[:, 1:] == pad_token_id] = -100
+        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)
 
         loss = outputs[0]
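
The _step rewrite keeps the standard seq2seq teacher-forcing shift, only with tidier names. A minimal sketch of what that shift does on a toy batch (illustration only, not part of the commit; assumes BART's pad token id of 1):

    import torch

    y = torch.tensor([[0, 8, 9, 2, 1, 1]])  # <s> tok tok </s> <pad> <pad>
    y_ids = y[:, :-1].contiguous()           # decoder inputs: everything but the last position
    lm_labels = y[:, 1:].clone()             # labels: the same ids shifted left by one
    lm_labels[y[:, 1:] == 1] = -100          # -100 makes cross-entropy ignore pad positions
    print(y_ids)      # tensor([[0, 8, 9, 2, 1]])
    print(lm_labels)  # tensor([[8, 9, 2, -100, -100]])
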
@@ -64,9 +54,13 @@ class BartSystem(BaseTransformer):
         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
 
     def test_step(self, batch, batch_idx):
+        # NOTE: this generation will not use the cache.
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
+        # NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
         generated_ids = self.model.generate(
-            batch["source_ids"],
-            attention_mask=batch["source_mask"],
+            source_ids,
+            source_mask,
             num_beams=1,
             max_length=80,
             repetition_penalty=2.5,
@@ -77,10 +71,7 @@ class BartSystem(BaseTransformer):
             self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
             for g in generated_ids
         ]
-        target = [
-            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            for t in batch["target_ids"]
-        ]
+        target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
         loss = self._step(batch)
 
         return {"val_loss": loss, "preds": preds, "target": target}
@@ -101,11 +92,21 @@ class BartSystem(BaseTransformer):
         return self.test_end(outputs)
 
-    def train_dataloader(self):
-        train_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
+    @property
+    def dataset_kwargs(self):
+        return dict(
+            data_dir=self.hparams.data_dir,
+            max_source_length=self.hparams.max_source_length,
+            max_target_length=self.hparams.max_target_length,
         )
-        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
+
+    def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
+        dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
+        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
+        return dataloader
+
+    def train_dataloader(self) -> DataLoader:
+        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size)
 
         t_total = (
             (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
             // self.hparams.gradient_accumulation_steps
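
Expanded, the new helpers mean get_dataloader("val", batch_size) is shorthand for the call below, a sketch that just unrolls the ** unpacking shown above:

    dataset = SummarizationDataset(
        self.tokenizer,
        type_path="val",
        data_dir=self.hparams.data_dir,
        max_source_length=self.hparams.max_source_length,
        max_target_length=self.hparams.max_target_length,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
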
@@ -117,29 +118,30 @@ class BartSystem(BaseTransformer):
         self.lr_scheduler = scheduler
         return dataloader
 
-    def val_dataloader(self):
-        val_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
-        )
-        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)
+    def val_dataloader(self) -> DataLoader:
+        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
 
-    def test_dataloader(self):
-        test_dataset = SummarizationDataset(
-            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
-        )
-        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)
+    def test_dataloader(self) -> DataLoader:
+        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
 
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         BaseTransformer.add_model_specific_args(parser, root_dir)
         # Add BART specific options
         parser.add_argument(
-            "--max_seq_length",
+            "--max_source_length",
             default=1024,
             type=int,
             help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
         )
+        parser.add_argument(
+            "--max_target_length",
+            default=56,
+            type=int,
+            help="The maximum total output sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded.",
+        )
 
         parser.add_argument(
             "--data_dir",
@@ -158,7 +160,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # If output_dir not provided, a folder will be generated in pwd
-    if args.output_dir is None:
+    if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)
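
After the rename, a run of the fine-tuning script passes both lengths explicitly, e.g. python finetune_bart.py --data_dir=./cnn_dm --max_source_length=1024 --max_target_length=56 (the script name and data path here are placeholders; only the flags come from this diff).
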

File 2 of 4: tests for the BART example (TestBartExamples)

@@ -5,28 +5,57 @@ import unittest
 from pathlib import Path
 from unittest.mock import patch
 
+from torch.utils.data import DataLoader
+
+from transformers import BartTokenizer
+
 from .evaluate_cnn import run_generate
+from .utils import SummarizationDataset
 
-articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
 
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 
 
+def _dump_articles(path: Path, articles: list):
+    with path.open("w") as f:
+        f.write("\n".join(articles))
+
+
 class TestBartExamples(unittest.TestCase):
     def test_bart_cnn_cli(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
         tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-        with tmp.open("w") as f:
-            f.write("\n".join(articles))
-
         output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
+        _dump_articles(tmp, articles)
         testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
             run_generate()
-            self.assertTrue(Path(output_file_name).exists())
+            self.assertTrue(output_file_name.exists())
+
+    def test_bart_summarization_dataset(self):
+        tmp_dir = Path(tempfile.gettempdir())
+        articles = [" Sam ate lunch today", "Sams lunch ingredients"]
+        summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
+        _dump_articles((tmp_dir / "train.source"), articles)
+        _dump_articles((tmp_dir / "train.target"), summaries)
+        tokenizer = BartTokenizer.from_pretrained("bart-large")
+        max_len_source = max(len(tokenizer.encode(a)) for a in articles)
+        max_len_target = max(len(tokenizer.encode(a)) for a in summaries)
+        trunc_target = 4
+        train_dataset = SummarizationDataset(
+            tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
+        )
+        dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
+        for batch in dataloader:
+            self.assertEqual(batch["source_mask"].shape, batch["source_ids"].shape)
+            # show that articles were trimmed.
+            self.assertEqual(batch["source_ids"].shape[1], max_len_source)
+            self.assertGreater(20, batch["source_ids"].shape[1])  # trimmed significantly
+            # show that targets were truncated
+            self.assertEqual(batch["target_ids"].shape[1], trunc_target)  # Truncated
+            self.assertGreater(max_len_target, trunc_target)  # Truncated
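
The new dataset test holds together because collate_fn first pads every example to max_source_length=20 and trim_batch then cuts the batch back to the width of the longest real sequence, max_len_source. Something like python -m pytest test_bart_examples.py -k summarization_dataset (module path assumed, not shown in this view) exercises it in isolation.
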

File 3 of 4: utils.py (SummarizationDataset)

@@ -1,35 +1,35 @@
 import os
 
 import torch
 from torch.utils.data import Dataset
 
+from transformers.tokenization_utils import trim_batch
+
+
+def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"):
+    examples = []
+    with open(data_path, "r") as f:
+        for text in f.readlines():
+            tokenized = tokenizer.batch_encode_plus(
+                [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors,
+            )
+            examples.append(tokenized)
+    return examples
+
 
 class SummarizationDataset(Dataset):
-    def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
-        super(SummarizationDataset,).__init__()
+    def __init__(
+        self,
+        tokenizer,
+        data_dir="./cnn-dailymail/cnn_dm/",
+        type_path="train",
+        max_source_length=1024,
+        max_target_length=56,
+    ):
+        super().__init__()
         self.tokenizer = tokenizer
-        self.source = []
-        self.target = []
-        print("loading " + type_path + " source.")
-        with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
-            for text in f.readlines():  # each text is a line and a full story
-                tokenized = tokenizer.batch_encode_plus(
-                    [text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
-                )
-                self.source.append(tokenized)
-        f.close()
-        print("loading " + type_path + " target.")
-        with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
-            for text in f.readlines():  # each text is a line and a summary
-                tokenized = tokenizer.batch_encode_plus(
-                    [text], max_length=56, pad_to_max_length=True, return_tensors="pt"
-                )
-                self.target.append(tokenized)
-        f.close()
+        self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
+        self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
 
     def __len__(self):
         return len(self.source)
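
A quick round-trip through the new encode_file helper, as a sketch (the temp path is made up, encode_file is the function defined above, and the bart-large vocab is assumed downloadable):

    import os
    import tempfile

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("bart-large")
    path = os.path.join(tempfile.gettempdir(), "demo.source")
    with open(path, "w") as f:
        f.write("Sam ate lunch today\nSams lunch ingredients")
    examples = encode_file(tokenizer, path, max_length=20)
    assert len(examples) == 2                                      # one entry per line
    assert all(e["input_ids"].shape == (1, 20) for e in examples)  # padded to max_length
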
@@ -37,7 +37,20 @@ class SummarizationDataset(Dataset):
     def __getitem__(self, index):
         source_ids = self.source[index]["input_ids"].squeeze()
         target_ids = self.target[index]["input_ids"].squeeze()
-        src_mask = self.source[index]["attention_mask"].squeeze()  # might need to squeeze
+        src_mask = self.source[index]["attention_mask"].squeeze()
         return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
+
+    @staticmethod
+    def trim_seq2seq_batch(batch, pad_token_id):
+        y = trim_batch(batch["target_ids"], pad_token_id)
+        source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"])
+        return source_ids, source_mask, y
+
+    def collate_fn(self, batch):
+        input_ids = torch.stack([x["source_ids"] for x in batch])
+        masks = torch.stack([x["source_mask"] for x in batch])
+        target_ids = torch.stack([x["target_ids"] for x in batch])
+        pad_token_id = self.tokenizer.pad_token_id
+        y = trim_batch(target_ids, pad_token_id)
+        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
+        return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y}

File 4 of 4: the shared BaseTransformer lightning module

@@ -47,27 +47,29 @@ def set_seed(args):
 
 class BaseTransformer(pl.LightningModule):
-    def __init__(self, hparams, num_labels=None, mode="base"):
+    def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
         "Initialize a model."
 
         super(BaseTransformer, self).__init__()
         self.hparams = hparams
+        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
         self.hparams.model_type = self.hparams.model_type.lower()
         config = AutoConfig.from_pretrained(
             self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
             **({"num_labels": num_labels} if num_labels is not None else {}),
-            cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
+            cache_dir=cache_dir,
+            **config_kwargs,
         )
         tokenizer = AutoTokenizer.from_pretrained(
             self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
             do_lower_case=self.hparams.do_lower_case,
-            cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
+            cache_dir=cache_dir,
         )
         model = MODEL_MODES[mode].from_pretrained(
             self.hparams.model_name_or_path,
             from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
             config=config,
-            cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
+            cache_dir=cache_dir,
        )
 
         self.config, self.tokenizer, self.model = config, tokenizer, model
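
The new **config_kwargs hook is what lets BartSystem pass output_past=False in the first file: any extra keyword flows through to AutoConfig.from_pretrained and lands on the config. In isolation the plumbing is just (a sketch; assumes the bart-large config can be downloaded):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bart-large", output_past=False)
    print(config.output_past)  # False
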