From c19d04623eacfbc2c452397a5eda0fde42db3fc5 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 14 Dec 2020 17:45:33 -0800
Subject: [PATCH] [finetune_trainer] enhancements and fixes (#9042)

* trainer and finetune_trainer enhancements and fixes

* add fallback default

* move the fixing of incorrect keys back into finetune trainer

* s/eval/val/ to match the split

* trainer can now use a different prefix than eval_ for metrics

* document new arg

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* use 'eval' as the default for metric_key_prefix

* complete adjust var names + disambiguate

* fix logger

* add clarifying comment

* add clarifying comment

* style

* Apply suggestions from code review

Co-authored-by: Patrick von Platen

* Update src/transformers/trainer.py

Co-authored-by: Patrick von Platen

* complete removal of optional for metric_key_prefix

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen
---
 examples/seq2seq/finetune_trainer.py | 95 ++++++++++++++++++++++------
 examples/seq2seq/utils.py            |  2 +-
 src/transformers/trainer.py          | 29 +++++++--
 3 files changed, 97 insertions(+), 29 deletions(-)

diff --git a/examples/seq2seq/finetune_trainer.py b/examples/seq2seq/finetune_trainer.py
index 22ec2d7ae3..578e32045c 100755
--- a/examples/seq2seq/finetune_trainer.py
+++ b/examples/seq2seq/finetune_trainer.py
@@ -16,6 +16,7 @@
 import logging
 import os
 import sys
+import time
 from dataclasses import dataclass, field
 from typing import Optional
 
@@ -119,6 +120,46 @@ class DataTrainingArguments:
     )
 
 
+def speed_metrics(split, start_time, num_samples):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` taken before the operation to be measured starts, and it
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: one of train, val, test
+    - start_time: operation start time
+    - num_samples: number of samples processed
+
+    """
+    runtime = time.time() - start_time
+    result = {}
+
+    samples_per_second = 1 / (runtime / num_samples)
+    result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    result[f"{split}_runtime"] = round(runtime, 4)
+
+    result[f"{split}_n_objs"] = num_samples
+    return result
+
+
+def handle_metrics(split, metrics, output_dir):
+    """
+    Log and save metrics
+
+    Args:
+    - split: one of train, val, test
+    - metrics: metrics dict
+    - output_dir: where to save the metrics
+    """
+
+    logger.info(f"***** {split} metrics *****")
+    for key, value in metrics.items():
+        logger.info(f"  {key} = {value}")
+    save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
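Illustrative aside, not part of the patch: a minimal sketch of how the two helpers added above are meant to be combined. The split name "val", the sample count, the roughly 20-second runtime, and the "outputs" directory are made-up values for illustration; `speed_metrics` and `handle_metrics` are the functions introduced in the hunk above.

import time

start_time = time.time()
# ... run the operation being timed here, e.g. trainer.evaluate(...) ...
metrics = speed_metrics("val", start_time, num_samples=1000)
# for a run of about 20 seconds this returns roughly:
#   {"val_n_objs": 1000, "val_runtime": 20.0, "val_samples_per_second": 50.0}
handle_metrics("val", metrics, "outputs")
# logs each key/value pair and writes them to outputs/val_results.json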
@@ -265,45 +306,56 @@ def main():
         data_args=data_args,
     )
 
+    all_metrics = {}
     # Training
     if training_args.do_train:
+        logger.info("*** Train ***")
+
+        start_time = time.time()
         trainer.train(
             model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
         )
-        trainer.save_model()
-        # For convenience, we also re-save the tokenizer to the same directory,
-        # so that you can share your model easily on huggingface.co/models =)
+        metrics = speed_metrics("train", start_time, data_args.n_train)
+
+        trainer.save_model()  # this also saves the tokenizer
+
         if trainer.is_world_process_zero():
+            handle_metrics("train", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
+
+            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
             trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+
+            # For convenience, we also re-save the tokenizer to the same directory,
+            # so that you can share your model easily on huggingface.co/models =)
             tokenizer.save_pretrained(training_args.output_dir)
 
     # Evaluation
-    eval_results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        result = trainer.evaluate()
+        start_time = time.time()
+        metrics = trainer.evaluate(metric_key_prefix="val")
+        metrics.update(speed_metrics("val", start_time, data_args.n_val))
+        metrics["val_loss"] = round(metrics["val_loss"], 4)
 
         if trainer.is_world_process_zero():
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-            save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
-            eval_results.update(result)
+
+            handle_metrics("val", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
 
     if training_args.do_predict:
-        logging.info("*** Test ***")
+        logger.info("*** Predict ***")
 
-        test_output = trainer.predict(test_dataset=test_dataset)
-        test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}
+        start_time = time.time()
+        test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
+        metrics = test_output.metrics
+        metrics.update(speed_metrics("test", start_time, data_args.n_test))
 
         if trainer.is_world_process_zero():
-            logger.info("***** Test results *****")
-            for key, value in test_metrics.items():
-                logger.info("  %s = %s", key, value)
-
-            save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
-            eval_results.update(test_metrics)
+            metrics["test_loss"] = round(metrics["test_loss"], 4)
+            handle_metrics("test", metrics, training_args.output_dir)
+            all_metrics.update(metrics)
 
             if training_args.predict_with_generate:
                 test_preds = tokenizer.batch_decode(
@@ -313,8 +365,9 @@ def main():
                 write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))
 
     if trainer.is_world_process_zero():
-        save_json(eval_results, "all_results.json")
-    return eval_results
+        save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
+
+    return all_metrics
 
 
 def _mp_fn(index):
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 70ef5f07ba..8014463122 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -462,7 +462,7 @@ def save_git_info(folder_path: str) -> None:
 
 def save_json(content, path, indent=4, **json_dump_kwargs):
     with open(path, "w") as f:
-        json.dump(content, f, indent=indent, **json_dump_kwargs)
+        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
 
 
 def load_json(path):
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8fbe1729fa..41f36917d9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1243,7 +1243,10 @@ class Trainer:
                 shutil.rmtree(checkpoint)
 
     def evaluate(
-        self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
     ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -1261,6 +1264,9 @@ class Trainer:
             ignore_keys (:obj:`Lst[str]`, `optional`):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
+            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                "eval_bleu" if the prefix is "eval" (default).
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1278,6 +1284,7 @@ class Trainer:
             # self.args.prediction_loss_only
             prediction_loss_only=True if self.compute_metrics is None else None,
             ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
         )
 
         self.log(output.metrics)
@@ -1289,7 +1296,9 @@ class Trainer:
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
         return output.metrics
 
-    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
+    def predict(
+        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval"
+    ) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
 
@@ -1303,6 +1312,9 @@ class Trainer:
             ignore_keys (:obj:`Lst[str]`, `optional`):
                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                 gathering predictions.
+            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                "eval_bleu" if the prefix is "eval" (default).
 
         .. note::
 
@@ -1322,7 +1334,9 @@ class Trainer:
 
         test_dataloader = self.get_test_dataloader(test_dataset)
 
-        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
+        return self.prediction_loop(
+            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+        )
 
     def prediction_loop(
         self,
@@ -1330,6 +1344,7 @@ class Trainer:
         description: str,
         prediction_loss_only: Optional[bool] = None,
         ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
     ) -> PredictionOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1421,12 +1436,12 @@ class Trainer:
 
         metrics = {}
         if eval_loss is not None:
-            metrics["eval_loss"] = eval_loss.mean().item()
+            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
 
-        # Prefix all keys with eval_
+        # Prefix all keys with metric_key_prefix + '_'
         for key in list(metrics.keys()):
-            if not key.startswith("eval_"):
-                metrics[f"eval_{key}"] = metrics.pop(key)
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
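Illustrative aside, not part of the patch: a minimal caller-side sketch of what the new metric_key_prefix argument changes. It assumes an already-built Trainer instance named trainer, a test_dataset variable, and a compute_metrics function that reports "bleu"; these names are assumptions for the example, not values taken from the diff above.

# default prefix: returned keys look like "eval_loss", "eval_bleu", ...
metrics = trainer.evaluate()

# custom prefix: the same metrics come back as "val_loss", "val_bleu", ...
val_metrics = trainer.evaluate(metric_key_prefix="val")

# predict() accepts the same argument, so test metrics stay distinct from eval metrics
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
test_metrics = test_output.metrics  # e.g. {"test_loss": ..., "test_bleu": ...}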