Small fixes to NotebookProgressCallback (#7813)
This commit is contained in:
parent
6f45dd2fac
commit
2ce3ddab2d
|
@ -153,7 +153,7 @@ try:
|
|||
import IPython # noqa: F401
|
||||
|
||||
_in_notebook = True
|
||||
except: # noqa: E722
|
||||
except (AttributeError, ImportError, KeyError):
|
||||
_in_notebook = False
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Optional
|
|||
import IPython.display as disp
|
||||
|
||||
from ..trainer_callback import TrainerCallback
|
||||
from ..trainer_utils import EvaluationStrategy
|
||||
|
||||
|
||||
def format_time(t):
|
||||
|
@ -146,7 +147,7 @@ class NotebookProgressBar:
|
|||
self.first_calls = self.warmup
|
||||
self.wait_for = 1
|
||||
self.update_bar(value)
|
||||
elif value <= self.last_value:
|
||||
elif value <= self.last_value and not force_update:
|
||||
return
|
||||
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
|
||||
if self.first_calls > 0:
|
||||
|
@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
def __init__(self):
|
||||
self.training_tracker = None
|
||||
self.prediction_bar = None
|
||||
self._force_next_update = False
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
|
||||
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
|
||||
self.training_loss = 0
|
||||
self.last_log = 0
|
||||
column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
|
||||
column_names = [self.first_column] + ["Training Loss"]
|
||||
if args.evaluation_strategy != EvaluationStrategy.NO:
|
||||
column_names.append("Validation Loss")
|
||||
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
|
||||
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")
|
||||
self.training_tracker.update(
|
||||
state.global_step + 1,
|
||||
comment=f"Epoch {epoch}/{state.num_train_epochs}",
|
||||
force_update=self._force_next_update,
|
||||
)
|
||||
self._force_next_update = False
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if self.prediction_bar is None:
|
||||
|
@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
else:
|
||||
self.prediction_bar.update(self.prediction_bar.value + 1)
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
# Only for when there is no evaluation
|
||||
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
|
||||
values = {"Training Loss": logs["loss"]}
|
||||
# First column is necessarily Step sine we're not in epoch eval strategy
|
||||
values["Step"] = state.global_step
|
||||
self.training_tracker.write_line(values)
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
if self.training_tracker is not None:
|
||||
values = {"Training Loss": "No log"}
|
||||
|
@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
|
|||
self.training_tracker.write_line(values)
|
||||
self.training_tracker.remove_child()
|
||||
self.prediction_bar = None
|
||||
# Evaluation takes a long time so we should force the next update.
|
||||
self._force_next_update = True
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
self.training_tracker.update(
|
||||
|
|
Loading…
Reference in New Issue