Add generate kwargs to Seq2SeqTrainingArguments (#13339)
* Add generate kwargs to Seq2SeqTrainingArguments * typo * Address review comments + doc * Style
This commit is contained in:
parent
702f4a49cd
commit
c76de1053e
|
@@ -556,12 +556,15 @@ def main():
|
|||
|
||||
# Evaluation
|
||||
results = {}
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
|
@@ -572,10 +575,7 @@ def main():
|
|||
logger.info("*** Predict ***")
|
||||
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset,
|
||||
metric_key_prefix="predict",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
|
||||
)
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
|
|
|
@@ -549,12 +549,16 @@ def main():
|
|||
|
||||
# Evaluation
|
||||
results = {}
|
||||
max_length = (
|
||||
training_args.generation_max_length
|
||||
if training_args.generation_max_length is not None
|
||||
else data_args.val_max_target_length
|
||||
)
|
||||
num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
|
@@ -565,10 +569,7 @@ def main():
|
|||
logger.info("*** Predict ***")
|
||||
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset,
|
||||
metric_key_prefix="predict",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
|
||||
)
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
|
|
|
@@ -70,10 +70,8 @@ class Seq2SeqTrainer(Trainer):
|
|||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
||||
dictionary also contains the epoch number which comes from the training state.
|
||||
"""
|
||||
if max_length is not None or not hasattr(self, "_max_length"):
|
||||
self._max_length = max_length
|
||||
if num_beams is not None or not hasattr(self, "_num_beams"):
|
||||
self._num_beams = num_beams
|
||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
def predict(
|
||||
|
@@ -119,10 +117,8 @@ class Seq2SeqTrainer(Trainer):
|
|||
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
||||
contained labels).
|
||||
"""
|
||||
if max_length is not None or not hasattr(self, "_max_length"):
|
||||
self._max_length = max_length
|
||||
if num_beams is not None or not hasattr(self, "_num_beams"):
|
||||
self._num_beams = num_beams
|
||||
self._max_length = max_length if max_length is not None else self.args.generation_max_length
|
||||
self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
|
||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
def prediction_step(
|
||||
|
|
|
@@ -14,6 +14,7 @@
|
|||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .training_args import TrainingArguments
|
||||
|
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
|
|||
the training set.
|
||||
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
|
||||
generation_max_length (:obj:`int`, `optional`):
|
||||
The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
|
||||
the :obj:`max_length` value of the model configuration.
|
||||
generation_num_beams (:obj:`int`, `optional`):
|
||||
The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
|
||||
:obj:`num_beams` value of the model configuration.
|
||||
"""
|
||||
|
||||
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
|
||||
predict_with_generate: bool = field(
|
||||
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||
)
|
||||
generation_max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||
"to the `max_length` value of the model configuration."
|
||||
},
|
||||
)
|
||||
generation_num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||
"to the `num_beams` value of the model configuration."
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue