Small fixes to NotebookProgressCallback (#7813)

This commit is contained in:
Sylvain Gugger 2020-10-15 10:30:34 -04:00 committed by GitHub
parent 6f45dd2fac
commit 2ce3ddab2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 5 deletions

View File

@ -153,7 +153,7 @@ try:
import IPython # noqa: F401
_in_notebook = True
except: # noqa: E722
except (AttributeError, ImportError, KeyError):
_in_notebook = False

View File

@ -19,6 +19,7 @@ from typing import Optional
import IPython.display as disp
from ..trainer_callback import TrainerCallback
from ..trainer_utils import EvaluationStrategy
def format_time(t):
@ -146,7 +147,7 @@ class NotebookProgressBar:
self.first_calls = self.warmup
self.wait_for = 1
self.update_bar(value)
elif value <= self.last_value:
elif value <= self.last_value and not force_update:
return
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
if self.first_calls > 0:
@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
def __init__(self):
self.training_tracker = None
self.prediction_bar = None
self._force_next_update = False
def on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
column_names = [self.first_column] + ["Training Loss"]
if args.evaluation_strategy != EvaluationStrategy.NO:
column_names.append("Validation Loss")
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
def on_step_end(self, args, state, control, **kwargs):
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")
self.training_tracker.update(
state.global_step + 1,
comment=f"Epoch {epoch}/{state.num_train_epochs}",
force_update=self._force_next_update,
)
self._force_next_update = False
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if self.prediction_bar is None:
@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
else:
self.prediction_bar.update(self.prediction_bar.value + 1)
def on_log(self, args, state, control, logs=None, **kwargs):
# Only for when there is no evaluation
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step
self.training_tracker.write_line(values)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None:
values = {"Training Loss": "No log"}
@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
self.training_tracker.write_line(values)
self.training_tracker.remove_child()
self.prediction_bar = None
# Evaluation takes a long time so we should force the next update.
self._force_next_update = True
def on_train_end(self, args, state, control, **kwargs):
self.training_tracker.update(