Add generate kwargs to Seq2SeqTrainingArguments (#13339)

* Add generate kwargs to Seq2SeqTrainingArguments

* typo

* Address review comments + doc

* Style
This commit is contained in:
Sylvain Gugger 2021-08-31 08:42:00 -04:00 committed by GitHub
parent 702f4a49cd
commit c76de1053e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 23 deletions

View File

@ -556,12 +556,15 @@ def main():
# Evaluation
results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@ -572,10 +575,7 @@ def main():
logger.info("*** Predict ***")
predict_results = trainer.predict(
predict_dataset,
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
)
metrics = predict_results.metrics
max_predict_samples = (

View File

@ -549,12 +549,16 @@ def main():
# Evaluation
results = {}
max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
)
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
@ -565,10 +569,7 @@ def main():
logger.info("*** Predict ***")
predict_results = trainer.predict(
predict_dataset,
metric_key_prefix="predict",
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
)
metrics = predict_results.metrics
max_predict_samples = (

View File

@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer):
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer):
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
contained labels).
"""
if max_length is not None or not hasattr(self, "_max_length"):
self._max_length = max_length
if num_beams is not None or not hasattr(self, "_num_beams"):
self._num_beams = num_beams
self._max_length = max_length if max_length is not None else self.args.generation_max_length
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(

View File

@ -14,6 +14,7 @@
import logging
from dataclasses import dataclass, field
from typing import Optional
from .file_utils import add_start_docstrings
from .training_args import TrainingArguments
@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
the training set.
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
generation_max_length (:obj:`int`, `optional`):
The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
the :obj:`max_length` value of the model configuration.
generation_num_beams (:obj:`int`, `optional`):
The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
:obj:`num_beams` value of the model configuration.
"""
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
predict_with_generate: bool = field(
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
)
generation_max_length: Optional[int] = field(
default=None,
metadata={
"help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
"to the `max_length` value of the model configuration."
},
)
generation_num_beams: Optional[int] = field(
default=None,
metadata={
"help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
"to the `num_beams` value of the model configuration."
},
)