332 lines
14 KiB
Python
332 lines
14 KiB
Python
import argparse
|
|
import glob
|
|
import logging
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
|
from transformers import get_linear_schedule_with_warmup
|
|
|
|
|
|
try:
|
|
from .utils import (
|
|
use_task_specific_params,
|
|
SummarizationDataset,
|
|
lmap,
|
|
flatten_list,
|
|
pickle_save,
|
|
save_git_info,
|
|
freeze_params,
|
|
calculate_rouge,
|
|
get_git_info,
|
|
ROUGE_KEYS,
|
|
)
|
|
from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
|
except ImportError:
|
|
from utils import (
|
|
use_task_specific_params,
|
|
SummarizationDataset,
|
|
lmap,
|
|
flatten_list,
|
|
pickle_save,
|
|
save_git_info,
|
|
freeze_params,
|
|
calculate_rouge,
|
|
get_git_info,
|
|
ROUGE_KEYS,
|
|
)
|
|
from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback
|
|
|
|
# Module-level logger, named after this module (standard `logging.getLogger(__name__)` pattern).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SummarizationModule(BaseTransformer):
    """Lightning module that fine-tunes a seq2seq model for summarization.

    Training minimizes the model's token-level loss. Validation and test steps additionally
    generate summaries, score them with ROUGE and append the aggregated metrics to
    ``self.metrics``, which is pickled to ``<output_dir>/metrics.pkl`` after every evaluation.
    """

    mode = "summarization"
    loss_names = ["loss"]

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        use_task_specific_params(self.model, "summarization")
        save_git_info(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.pkl"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        # Incremented once per call to validation_end (val and test).
        self.step_count = 0
        self.metrics = {"train": [], "val": [], "test": []}

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count (the CLI default is -1) becomes None, i.e. "use every example".
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
        self.hparams.git_sha = get_git_info()["repo_sha"]
        # None signals "don't use worker processes": passing num_workers breaks lightning for multigpu.
        # get_dataloader translates None into DataLoader's default of 0.
        self.num_workers = 4 if self.hparams.gpus <= 1 else None

    def freeze_embeds(self):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        if self.model.config.model_type == "bart":
            freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        else:
            freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                freeze_params(d.embed_tokens)

    def forward(self, input_ids, **kwargs):
        """Delegate directly to the wrapped transformer model."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids):
        """Decode a batch of token-id sequences (e.g. ``generate`` output) into stripped strings.

        Special tokens are dropped and tokenization spaces cleaned up by ``batch_decode``.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Compute the teacher-forced loss for ``batch``; returns a 1-tuple ``(loss,)``.

        ``batch`` must provide ``input_ids``, ``attention_mask`` and ``decoder_input_ids``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
        # Decoder inputs are the target without its last token; labels are the target shifted left by one.
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        # Padding positions get -100, which the model's loss presumably ignores
        # (-100 is torch CrossEntropyLoss's default ignore_index).
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, labels=lm_labels)
        loss = outputs[0]
        return (loss,)

    def training_step(self, batch, batch_idx) -> Dict:
        """Return the training loss plus a ``log`` dict keyed by ``loss_names``."""
        loss_tensors = self._step(batch)
        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate summaries for ``batch`` and compute loss/ROUGE metrics for it."""
        return self._generative_step(batch)

    def validation_end(self, outputs, prefix="val") -> Dict:
        """Aggregate per-batch generative outputs into epoch-level metrics.

        Averages each loss in ``loss_names``, each ROUGE key, ``gen_time`` and ``summ_len``.
        The metrics are stored as ``{prefix}_avg_<name>`` plus ``step_count`` and persisted
        via ``save_metrics``.

        Returns:
            dict with ``log`` (the metrics), ``preds`` (all predictions, flattened),
            ``{prefix}_loss`` (mean loss tensor) and ``{prefix}_rouge`` (mean rouge2 as a tensor).
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]}
        rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss)
        # Merge into one dict of plain Python numbers (loss tensors are converted with .item()).
        rouges.update({k: v.item() for k, v in losses.items()})
        losses.update(rouges)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor}

    def save_metrics(self, metrics, prefix) -> None:
        """Append ``metrics`` under ``prefix`` and re-pickle all metrics to ``metrics_save_path``."""
        self.metrics[prefix].append(metrics)
        pickle_save(self.metrics, self.metrics_save_path)

    def _generative_step(self, batch):
        """Generate summaries for ``batch`` and return loss, ROUGE, timing and text outputs."""
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        # TODO(SS): task specific params

        t0 = time.time()
        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True)
        gen_time = time.time() - t0
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        # Mean length (in tokens, padding included) of the generated rows in this batch.
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Same as ``validation_step``, on the test split."""
        return self._generative_step(batch)

    def test_end(self, outputs):
        """Aggregate test outputs with the ``test`` metric prefix."""
        return self.validation_end(outputs, prefix="test")

    def test_epoch_end(self, outputs):
        """Write test predictions/targets (one per line) to ``output_dir``, then aggregate metrics."""
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        # Explicit UTF-8 so generated text never depends on the platform's locale encoding;
        # the `with` statement closes both files, so no manual close() is needed.
        with open(output_test_predictions_file, "w", encoding="utf-8") as p_writer, open(
            output_test_targets_file, "w", encoding="utf-8"
        ) as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])

        return self.test_end(outputs)

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs and hand the result back to Lightning.

        Returning the dict (rather than discarding it) is what makes the ``log`` entries and the
        ``val_loss``/``val_rouge`` values visible to the trainer's loggers and callbacks.
        """
        return self.validation_end(outputs, "val")

    def get_dataset(self, type_path) -> SummarizationDataset:
        """Build the ``SummarizationDataset`` for split ``type_path`` ("train", "val" or "test")."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = SummarizationDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Build a ``DataLoader`` for ``type_path``, using the sortish sampler for training if requested."""
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            # TODO: assert earlier
            assert self.hparams.gpus <= 1, "--sortish_sampler is not supported with more than one GPU"
            sampler = dataset.make_sortish_sampler(batch_size)
            # DataLoader rejects shuffle=True combined with a custom sampler.
            shuffle = False

        # self.num_workers is None for multi-GPU runs. DataLoader requires an int, so use its default (0).
        num_workers = self.num_workers if self.num_workers is not None else 0
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader and, as a side effect, create ``self.lr_scheduler``.

        ``t_total`` approximates the number of optimizer steps:
        batches per device group // gradient accumulation steps * epochs.
        ``self.opt`` is presumably created elsewhere (e.g. BaseTransformer) — confirm.
        """
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the (unshuffled) validation loader."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Return the (unshuffled) test loader."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base, generic and summarization-specific CLI arguments on ``parser``."""
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length for training after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total target sequence length for validation after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total target sequence length for testing after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--data_dir",
            type=str,
            required=True,
            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        return parser
|
|
|
|
|
|
def main(args, model=None) -> SummarizationModule:
    """Train and optionally test a summarization model according to ``args``.

    Args:
        args: parsed CLI namespace (see ``SummarizationModule.add_model_specific_args``).
        model: an already-constructed module to use. When None, a ``SummarizationModule`` is built from ``args``.

    Returns:
        The (trained) model.

    Raises:
        ValueError: if training into a non-empty ``output_dir`` or if ``args.logger`` is unrecognized.
    """
    # parents=True so nested output directories can be created in one go.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Up to 3 pre-existing entries are tolerated; presumably auxiliary files (e.g. git info) — TODO confirm.
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError(f"Output directory ({args.output_dir}) already exists and is not empty.")
    if model is None:
        model = SummarizationModule(args)

    # Named training_logger so it does not shadow the module-level `logger`.
    if (
        args.logger == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith(("/tmp", "/var"))
    ):
        training_logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        training_logger = WandbLogger(name=model.output_dir.name)
    elif args.logger == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB.
        training_logger = WandbLogger(name=model.output_dir.name, project="hf_summarization")
    else:
        # argparse restricts --logger, but programmatic callers can pass anything; fail clearly
        # instead of with an UnboundLocalError below.
        raise ValueError(f"Unknown logger {args.logger!r}; expected 'default', 'wandb' or 'wandb_shared'.")

    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir),
        logger=training_logger,
        # TODO: early stopping callback seems messed up
    )
    if not args.do_predict:
        return model

    # Test from the lexicographically last checkpoint in output_dir, if any exist.
    model.hparams.test_checkpoint = ""
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]
    trainer.logger.log_hyperparams(model.hparams)
    trainer.test(model)  # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # Assemble the full CLI (base + generic + summarization flags) rooted at the CWD, then run.
    cli_parser = SummarizationModule.add_model_specific_args(argparse.ArgumentParser(), os.getcwd())
    main(cli_parser.parse_args())
|