117 lines
4.3 KiB
Python
117 lines
4.3 KiB
Python
import logging
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
|
from pytorch_lightning.utilities import rank_zero_only
|
|
|
|
from utils import save_json
|
|
|
|
|
|
def count_trainable_parameters(model):
    """Return the number of scalar elements in the trainable parameters of *model*.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors with a
            ``requires_grad`` flag and a ``numel()`` method (e.g. a
            ``torch.nn.Module``).

    Returns:
        The total element count, as a plain Python ``int``, over all parameters
        whose ``requires_grad`` is true. ``0`` when there are none.
    """
    # ``numel()`` yields an exact Python int per tensor, so the sum cannot
    # overflow and never turns into a NumPy scalar/float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
|
|
|
# Module-level logger named after this module; used by the callback below to
# announce result dumps.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Seq2SeqLoggingCallback(pl.Callback):
    """Lightning callback that logs learning rates, parameter counts and eval results.

    Besides forwarding values to the attached logger(s), it persists
    ``pl_module.metrics`` to ``pl_module.metrics_save_path`` after every
    validation and test run, and writes text result/generation files under
    ``pl_module.hparams.output_dir`` after testing.
    """

    def on_batch_end(self, trainer, pl_module):
        """Log the current learning rate of each param group of the first optimizer."""
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        """Write ``trainer.callback_metrics`` (and optionally predictions) to disk.

        Args:
            trainer: Trainer whose ``callback_metrics`` and ``global_step`` are read.
            pl_module: Module providing ``hparams.output_dir``.
            type_path: Split name. ``"test"`` writes to fixed file names directly
                in the output directory; anything else writes per-step files in
                ``<type_path>_results/`` and ``<type_path>_generations/``.
            save_generations: When false, skip writing the generations file.
        """
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        # Bookkeeping entries that are neither forwarded to the logger nor
        # written to the results file.
        skipped_keys = ("log", "progress_bar", "preds")
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in skipped_keys})

        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"

        # parents=True so a not-yet-created output directory does not abort logging.
        results_file.parent.mkdir(parents=True, exist_ok=True)
        generations_file.parent.mkdir(parents=True, exist_ok=True)

        # Append so repeated runs accumulate in the same results file.
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in skipped_keys:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            # Context manager guarantees the handle is flushed and closed.
            with generations_file.open("w+") as gen_writer:
                gen_writer.write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """Log total and trainable parameter counts once training starts."""
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            # No nested ``.model`` (or it lacks ``num_parameters``): ask the
            # outer model instead.
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Save the module's metrics as JSON, then write test results and generations."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        """Save the module's metrics as JSON after each validation run."""
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
|
|
|
|
|
|
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
    """Build a ``ModelCheckpoint`` that keeps the best models by a validation metric.

    Args:
        output_dir: Directory passed to the checkpoint callback as ``dirpath``.
        metric: One of ``"rouge2"``, ``"bleu"`` or ``"loss"``. The monitored key
            is ``val_<metric>``; the checkpoint file name embeds
            ``val_avg_<metric>`` and ``step_count``.
        save_top_k: Forwarded to ``ModelCheckpoint`` unchanged.
        lower_is_better: Accepted for interface compatibility but not read; the
            mode is ``"min"`` when ``"loss"`` occurs in *metric*, else ``"max"``.

    Returns:
        The configured ``ModelCheckpoint`` instance.

    Raises:
        NotImplementedError: If *metric* is not one of the supported names.
    """
    filename_templates = (
        ("rouge2", "{val_avg_rouge2:.4f}-{step_count}"),
        ("bleu", "{val_avg_bleu:.4f}-{step_count}"),
        ("loss", "{val_avg_loss:.4f}-{step_count}"),
    )
    for supported_metric, template in filename_templates:
        if metric == supported_metric:
            exp = template
            break
    else:
        # No supported metric matched.
        raise NotImplementedError(
            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to"
            " this function."
        )

    monitored_key = f"val_{metric}"
    direction = "min" if "loss" in metric else "max"
    return ModelCheckpoint(
        dirpath=output_dir,
        filename=exp,
        monitor=monitored_key,
        mode=direction,
        save_top_k=save_top_k,
    )
|
|
|
|
|
|
def get_early_stopping_callback(metric, patience):
    """Build a verbose ``EarlyStopping`` callback monitoring ``val_<metric>``.

    Args:
        metric: Metric name; the monitored key is ``val_<metric>``. The mode is
            ``"min"`` when ``"loss"`` occurs in the name, otherwise ``"max"``.
        patience: Forwarded to ``EarlyStopping`` unchanged.

    Returns:
        The configured ``EarlyStopping`` instance.
    """
    # Open question carried over from the original author: does this key need
    # the "avg" infix?
    monitored_key = f"val_{metric}"
    direction = "min" if "loss" in metric else "max"
    return EarlyStopping(monitor=monitored_key, mode=direction, patience=patience, verbose=True)
|