From c76de1053e76010340a3cf152e51d4d9f5a1f755 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 31 Aug 2021 08:42:00 -0400 Subject: [PATCH] Add generate kwargs to Seq2SeqTrainingArguments (#13339) * Add generate kwargs to Seq2SeqTrainingArguments * typo * Address review comments + doc * Style --- .../summarization/run_summarization.py | 16 +++++++------- .../pytorch/translation/run_translation.py | 15 ++++++------- src/transformers/trainer_seq2seq.py | 12 ++++------- src/transformers/training_args_seq2seq.py | 21 +++++++++++++++++++ 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 73664be799..ce2deec706 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -556,12 +556,15 @@ def main(): # Evaluation results = {} + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams if training_args.do_eval: logger.info("*** Evaluate ***") - - metrics = trainer.evaluate( - max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval" - ) + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) @@ -572,10 +575,7 @@ def main(): logger.info("*** Predict ***") predict_results = trainer.predict( - predict_dataset, - metric_key_prefix="predict", - max_length=data_args.val_max_target_length, - num_beams=data_args.num_beams, + predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams ) metrics = predict_results.metrics max_predict_samples = ( diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index c8c58ac19b..d00d3f96bf 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -549,12 +549,16 @@ def main(): # Evaluation results = {} + max_length = ( + training_args.generation_max_length + if training_args.generation_max_length is not None + else data_args.val_max_target_length + ) + num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate( - max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval" - ) + metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) @@ -565,10 +569,7 @@ def main(): logger.info("*** Predict ***") predict_results = trainer.predict( - predict_dataset, - metric_key_prefix="predict", - max_length=data_args.val_max_target_length, - num_beams=data_args.num_beams, + predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams ) metrics = predict_results.metrics max_predict_samples = ( diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index bfcded021f..1995677801 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer): A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state. """ - if max_length is not None or not hasattr(self, "_max_length"): - self._max_length = max_length - if num_beams is not None or not hasattr(self, "_num_beams"): - self._num_beams = num_beams + self._max_length = max_length if max_length is not None else self.args.generation_max_length + self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) def predict( @@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer): - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset contained labels). """ - if max_length is not None or not hasattr(self, "_max_length"): - self._max_length = max_length - if num_beams is not None or not hasattr(self, "_num_beams"): - self._num_beams = num_beams + self._max_length = max_length if max_length is not None else self.args.generation_max_length + self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) def prediction_step( diff --git a/src/transformers/training_args_seq2seq.py b/src/transformers/training_args_seq2seq.py index 8527fda1fd..02b9a77be0 100644 --- a/src/transformers/training_args_seq2seq.py +++ b/src/transformers/training_args_seq2seq.py @@ -14,6 +14,7 @@ import logging from dataclasses import dataclass, field +from typing import Optional from .file_utils import add_start_docstrings from .training_args import TrainingArguments @@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments): the training set. predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to use generate to calculate generative metrics (ROUGE, BLEU). + generation_max_length (:obj:`int`, `optional`): + The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to + the :obj:`max_length` value of the model configuration. + generation_num_beams (:obj:`int`, `optional`): + The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the + :obj:`num_beams` value of the model configuration. """ sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."}) predict_with_generate: bool = field( default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."} ) + generation_max_length: Optional[int] = field( + default=None, + metadata={ + "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `max_length` value of the model configuration." + }, + ) + generation_num_beams: Optional[int] = field( + default=None, + metadata={ + "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default " + "to the `num_beams` value of the model configuration." + }, + )