transformers/examples/summarization/finetune.py

import argparse
import glob
import logging
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import get_linear_schedule_with_warmup
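# Support running this file both as part of the examples package (relative
# imports) and as a standalone script (plain imports).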
try:
    from .utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
except ImportError:
    from utils import (
        use_task_specific_params,
        SummarizationDataset,
        lmap,
        flatten_list,
        pickle_save,
        save_git_info,
        freeze_params,
        calculate_rouge,
        get_git_info,
        ROUGE_KEYS,
    )
    from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback

logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        self.step_count = 0
        self.metrics = {"train": [], "val": [], "test": []}
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
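        # Note: `prefix` above is prepended to every source document before
        # tokenization; T5-style checkpoints set config.prefix (e.g.
        # "summarize: ") through their task_specific_params, BART leaves it empty.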
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = 4 if self.hparams.gpus <= 1 else None  # passing num_workers breaks lightning for multigpu

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model.config.model_type == "bart":
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: torch.Tensor) -> List[str]:
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == pad_token_id] = -100
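        # The slices implement teacher forcing: the decoder sees the target
        # shifted right by one token and must predict the next one; -100 marks
        # padded label positions so the cross-entropy loss ignores them.
        # Illustrative toy values (assuming pad_token_id == 1):
        #   y         = [[0, 42, 7, 1, 1]]
        #   y_ids     = [[0, 42, 7, 1]]
        #   lm_labels = [[42, 7, -100, -100]]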
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels,)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_end(self, outputs, prefix="val") -> Dict:
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
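        # ROUGE-2 is promoted to a tensor on the same device/dtype as the loss
        # so the checkpoint callback can monitor it like any other lightning
        # metric.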
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}

    def save_metrics(self, metrics, prefix) -> None:
        self.metrics[prefix].append(metrics)
        pickle_save(self.metrics, self.metrics_save_path)

    def _generative_step(self, batch):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        # TODO(SS): task specific params
        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
        gen_time = time.time() - t0
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        return self._generative_step(batch)

    def test_end(self, outputs):
        return self.validation_end(outputs, prefix="test")

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        # the context manager closes both files, so no explicit close() is needed.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
        return self.test_end(outputs)

    def validation_epoch_end(self, outputs):
        return self.validation_end(outputs, "val")

    def get_dataset(self, type_path) -> SummarizationDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
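        # The sortish sampler buckets examples of similar source length
        # together, which cuts padding waste while keeping batches roughly
        # shuffled; it replaces the usual random shuffle below.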
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
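        # t_total estimates the total number of optimizer steps: batches per
        # device, divided by gradient accumulation, times epochs. E.g. 10,000
        # examples, batch size 4 on 1 GPU, accumulation 2, 3 epochs ->
        # (10000 // 4) // 2 * 3.0 = 3750.0 steps.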
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        return parser


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        model: BaseTransformer = SummarizationModule(args)
    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=logger,
        # TODO: early stopping callback seems messed up
    )
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
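
# Example invocation (paths and values are illustrative; flags such as
# --model_name_or_path and --do_train come from the shared lightning_base
# argument parser, see the summarization README for tested recipes):
#   python finetune.py \
#     --model_name_or_path=facebook/bart-large \
#     --data_dir=./cnn_dm \
#     --output_dir=./bart_cnndm \
#     --do_train --do_predict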