[trainer] add sanity evaluation option (#31146)

* add sanity evaluation

* fix

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
Marc Sun 2024-05-31 12:44:20 +02:00 committed by GitHub
parent fc5d3e112a
commit f8e6ba454c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 9 deletions

View File

@ -2175,6 +2175,9 @@ class Trainer:
grad_norm: Optional[float] = None
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if args.sanity_evaluation:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_iterator = train_dataloader
@ -2723,6 +2726,18 @@ class Trainer:
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
)
def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
return metrics
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
@ -2749,15 +2764,7 @@ class Trainer:
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
metrics = self._evaluate(trial, ignore_keys_for_eval)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)

View File

@ -771,6 +771,9 @@ class TrainingArguments:
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
sanity_evaluation(`bool`, *optional*, defaults to `False`):
Whether or not to perform a sanity check to ensure that the validation steps works correctly. It will be performed before the training.
"""
framework = "pt"
@ -1454,6 +1457,13 @@ class TrainingArguments:
metadata={"help": "Break eval metrics calculation into batches to save memory."},
)
sanity_evaluation: bool = field(
default=False,
metadata={
"help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
},
)
def __post_init__(self):
# Parse in args that could be `dict` sent in from the CLI as a string
for field in _VALID_DICT_FIELDS: