[trainer] add sanity evaluation option (#31146)
* add sanity evaluation
* fix
* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Parent: fc5d3e112a
Commit: f8e6ba454c
@@ -2175,6 +2175,9 @@ class Trainer:
         grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
+        if args.sanity_evaluation:
+            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
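For orientation, a minimal sketch of how the new flag is meant to be used from the user side (the model and dataset variables are placeholders, not part of this diff):

    from transformers import Trainer, TrainingArguments

    # `model`, `train_ds`, and `eval_ds` are assumed to be defined elsewhere.
    args = TrainingArguments(
        output_dir="out",
        eval_strategy="steps",
        eval_steps=500,
        sanity_evaluation=True,  # run one full evaluation pass before training starts
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
    trainer.train()  # the sanity evaluation fires right after on_train_begin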
@@ -2723,6 +2726,18 @@ class Trainer:
                 f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
             )
 
+    def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
+        metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+        self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+        # Run delayed LR scheduler now that metrics are populated
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
+            metric_to_check = self.args.metric_for_best_model
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+            self.lr_scheduler.step(metrics[metric_to_check])
+
+        return metrics
+
     def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_xla_available():
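The `skip_scheduler=True` path matters because the sanity pass happens before any optimizer step: stepping a metric-driven scheduler there would consume a plateau reading. A self-contained sketch of that scheduler's behavior (standalone PyTorch, not Trainer code):

    import torch

    param = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([param], lr=0.1)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=0, factor=0.5)

    # Every step(metric) call counts toward the plateau patience, so an extra
    # call from a pre-training sanity evaluation would skew the schedule.
    for loss in [1.0, 1.0, 0.9]:
        sched.step(loss)
    print(opt.param_groups[0]["lr"])  # 0.05 -- halved after the non-improving reading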
@@ -2749,15 +2764,7 @@ class Trainer:
 
         metrics = None
         if self.control.should_evaluate:
-            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
-            self._report_to_hp_search(trial, self.state.global_step, metrics)
-
-            # Run delayed LR scheduler now that metrics are populated
-            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
-                metric_to_check = self.args.metric_for_best_model
-                if not metric_to_check.startswith("eval_"):
-                    metric_to_check = f"eval_{metric_to_check}"
-                self.lr_scheduler.step(metrics[metric_to_check])
+            metrics = self._evaluate(trial, ignore_keys_for_eval)
 
         if self.control.should_save:
             self._save_checkpoint(model, trial, metrics=metrics)
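One way to check the new ordering end to end (a hypothetical probe, not part of the PR): record the global step of every evaluation via a `TrainerCallback`; with `sanity_evaluation=True`, the first recorded step should be 0.

    from transformers import TrainerCallback

    class EvalOrderProbe(TrainerCallback):
        """Records the global step at which each evaluation ran."""

        def __init__(self):
            self.eval_steps = []

        def on_evaluate(self, args, state, control, **kwargs):
            self.eval_steps.append(state.global_step)

    # Passed via Trainer(..., callbacks=[probe]); with sanity_evaluation=True,
    # probe.eval_steps starts with 0 (the pre-training pass).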
@@ -771,6 +771,9 @@ class TrainingArguments:
             rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
+        sanity_evaluation (`bool`, *optional*, defaults to `False`):
+            Whether or not to perform a sanity check to ensure that the validation step works correctly. It will be
+            performed before training.
     """
 
     framework = "pt"
@@ -1454,6 +1457,13 @@ class TrainingArguments:
         metadata={"help": "Break eval metrics calculation into batches to save memory."},
     )
 
+    sanity_evaluation: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
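Since `sanity_evaluation` is a plain dataclass field, it is also picked up by `HfArgumentParser`, so command-line scripts get the flag for free. A quick sketch:

    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    # Equivalent to passing --sanity_evaluation True on the command line.
    (args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "out", "--sanity_evaluation", "True"]
    )
    print(args.sanity_evaluation)  # True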