Add specific notebook ProgressCalback (#7793)

This commit is contained in:
Sylvain Gugger 2020-10-15 05:05:08 -04:00 committed by GitHub
parent 0911b6bd86
commit 62b5622e6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 352 additions and 2 deletions

View File

@ -142,6 +142,20 @@ try:
except (AttributeError, ImportError):
_has_sklearn = False
try:
# Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
import IPython # noqa: F401
_in_notebook = True
except (ImportError, KeyError):
_in_notebook = False
default_cache_path = os.path.join(torch_cache_home, "transformers")
@ -203,6 +217,10 @@ def is_faiss_available():
return _faiss_available
def is_in_notebook():
return _in_notebook
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:

View File

@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
is_comet_available,
@ -89,7 +89,12 @@ _use_native_amp = False
_use_apex = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
if is_in_notebook():
from .utils.notebook import NotebookProgressCallback
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
@ -235,7 +240,7 @@ class Trainer:
)
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
# Deprecated arguments
if "tb_writer" in kwargs:

View File

@ -0,0 +1,327 @@
# coding=utf-8
# Copyright 2020 Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Optional
import IPython.display as disp
from ..trainer_callback import TrainerCallback
def format_time(t):
"Format `t` (in seconds) to (h):mm:ss"
t = int(t)
h, m, s = t // 3600, (t // 60) % 60, t % 60
return f"{h}:{m:02d}:{s:02d}" if h != 0 else f"{m:02d}:{s:02d}"
def html_progress_bar(value, total, prefix, label, width=300):
"Html code for a progress bar `value`/`total` with `label` on the right, `prefix` on the left."
return f"""
<div>
<style>
/* Turns off some styling */
progress {{
/* gets rid of default border in Firefox and Opera. */
border: none;
/* Needs to be in here for Safari polyfill so background images work as expected. */
background-size: auto;
}}
</style>
{prefix}
<progress value='{value}' max='{total}' style='width:{width}px; height:20px; vertical-align: middle;'></progress>
{label}
</div>
"""
def text_to_html_table(items):
"Put the texts in `items` in an HTML table."
html_code = """<table border="1" class="dataframe">\n"""
html_code += """ <thead>\n <tr style="text-align: left;">\n"""
for i in items[0]:
html_code += f" <th>{i}</th>\n"
html_code += " </tr>\n </thead>\n <tbody>\n"
for line in items[1:]:
html_code += " <tr>\n"
for elt in line:
elt = f"{elt:.6f}" if isinstance(elt, float) else str(elt)
html_code += f" <td>{elt}</td>\n"
html_code += " </tr>\n"
html_code += " </tbody>\n</table><p>"
return html_code
class NotebookProgressBar:
"""
A progress par for display in a notebook.
Class attributes (overridden by derived classes)
- **warmup** (:obj:`int`) -- The number of iterations to do at the beginning while ignoring
:obj:`update_every`.
- **update_every** (:obj:`float`) -- Since calling the time takes some time, we only do it
every presumed :obj:`update_every` seconds. The progress bar uses the average time passed
up until now to guess the next value for which it will call the update.
Args:
total (:obj:`int`):
The total number of iterations to reach.
prefix (:obj:`str`, `optional`):
A prefix to add before the progress bar.
leave (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to leave the progress bar once it's completed. You can always call the
:meth:`~transformers.utils.notebook.NotebookProgressBar.close` method to make the bar disappear.
parent (:class:`~transformers.notebook.NotebookTrainingTracker`, `optional`):
A parent object (like :class:`~transformers.utils.notebook.NotebookTrainingTracker`) that spawns progress
bars and handle their display. If set, the object passed must have a :obj:`display()` method.
width (:obj:`int`, `optional`, defaults to 300):
The width (in pixels) that the bar will take.
Example::
import time
pbar = NotebookProgressBar(100)
for val in range(100):
pbar.update(val)
time.sleep(0.07)
pbar.update(100)
"""
warmup = 5
update_every = 0.2
def __init__(
self,
total: int,
prefix: Optional[str] = None,
leave: bool = True,
parent: Optional["NotebookTrainingTracker"] = None,
width: int = 300,
):
self.total = total
self.prefix = "" if prefix is None else prefix
self.leave = leave
self.parent = parent
self.width = width
self.last_value = None
self.comment = None
self.output = None
def update(self, value: int, force_update: bool = False, comment: str = None):
"""
The main method to update the progress bar to :obj:`value`.
Args:
value (:obj:`int`):
The value to use. Must be between 0 and :obj:`total`.
force_update (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force and update of the internal state and display (by default, the bar will wait for
:obj:`value` to reach the value it predicted corresponds to a time of more than the :obj:`update_every`
attribute since the last update to avoid adding boilerplate).
comment (:obj:`str`, `optional`):
A comment to add on the left of the progress bar.
"""
self.value = value
if comment is not None:
self.comment = comment
if self.last_value is None:
self.start_time = self.last_time = time.time()
self.start_value = self.last_value = value
self.elapsed_time = self.predicted_remaining = None
self.first_calls = self.warmup
self.wait_for = 1
self.update_bar(value)
elif value <= self.last_value:
return
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
if self.first_calls > 0:
self.first_calls -= 1
current_time = time.time()
self.elapsed_time = current_time - self.start_time
self.average_time_per_item = self.elapsed_time / (value - self.start_value)
if value >= self.total:
value = self.total
self.predicted_remaining = None
if not self.leave:
self.close()
else:
self.predicted_remaining = self.average_time_per_item * (self.total - value)
self.update_bar(value)
self.last_value = value
self.last_time = current_time
self.wait_for = max(int(self.update_every / self.average_time_per_item), 1)
def update_bar(self, value, comment=None):
spaced_value = " " * (len(str(self.total)) - len(str(value))) + str(value)
if self.elapsed_time is None:
self.label = f"[{spaced_value}/{self.total} : < :"
elif self.predicted_remaining is None:
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
else:
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} < {format_time(self.predicted_remaining)}"
self.label += f", {1/self.average_time_per_item:.2f} it/s"
self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
self.display()
def display(self):
self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
if self.parent is not None:
# If this is a child bar, the parent will take care of the display.
self.parent.display()
return
if self.output is None:
self.output = disp.display(disp.HTML(self.html_code), display_id=True)
else:
self.output.update(disp.HTML(self.html_code))
def close(self):
"Closes the progress bar."
if self.parent is None and self.output is not None:
self.output.update(disp.HTML(""))
class NotebookTrainingTracker(NotebookProgressBar):
"""
An object tracking the updates of an ongoing training with progress bars and a nice table reporting metrics.
Args:
num_steps (:obj:`int`): The number of steps during training.
column_names (:obj:`List[str]`, `optional`):
The list of column names for the metrics table (will be infered from the first call to
:meth:`~transformers.utils.notebook.NotebookTrainingTracker.write_line` if not set).
"""
def __init__(self, num_steps, column_names=None):
super().__init__(num_steps)
self.inner_table = None if column_names is None else [column_names]
self.child_bar = None
def display(self):
self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
if self.inner_table is not None:
self.html_code += text_to_html_table(self.inner_table)
if self.child_bar is not None:
self.html_code += self.child_bar.html_code
if self.output is None:
self.output = disp.display(disp.HTML(self.html_code), display_id=True)
else:
self.output.update(disp.HTML(self.html_code))
def write_line(self, values):
"""
Write the values in the inner table.
Args:
values (:obj:`Dict[str, float]`): The values to display.
"""
if self.inner_table is None:
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
if len(self.inner_table) == 1:
# We give a chance to update the column names at the first iteration
for key in values.keys():
if key not in columns:
columns.append(key)
self.inner_table[0] = columns
self.inner_table.append([values[c] for c in columns])
def add_child(self, total, prefix=None, width=300):
"""
Add a child progress bar disaplyed under the table of metrics. The child progress bar is returned (so it can
be easily updated).
Args:
total (:obj:`int`): The number of iterations for the child progress bar.
prefix (:obj:`str`, `optional`): A prefix to write on the left of the progress bar.
width (:obj:`int`, `optional`, defaults to 300): The width (in pixels) of the progress bar.
"""
self.child_bar = NotebookProgressBar(total, prefix=prefix, parent=self, width=width)
return self.child_bar
def remove_child(self):
"""
Closes the child progress bar.
"""
self.child_bar = None
self.display()
class NotebookProgressCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation, optimized for
Jupyter Notebooks or Google colab.
"""
def __init__(self):
self.training_tracker = None
self.prediction_bar = None
def on_train_begin(self, args, state, control, **kwargs):
self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
self.training_loss = 0
self.last_log = 0
column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
def on_step_end(self, args, state, control, **kwargs):
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if self.prediction_bar is None:
if self.training_tracker is not None:
self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
else:
self.prediction_bar = NotebookProgressBar(len(eval_dataloader))
self.prediction_bar.update(1)
else:
self.prediction_bar.update(self.prediction_bar.value + 1)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if self.training_tracker is not None:
values = {"Training Loss": "No log"}
for log in reversed(state.log_history):
if "loss" in log:
values["Training Loss"] = log["loss"]
break
if self.first_column == "Epoch":
values["Epoch"] = int(state.epoch)
else:
values["Step"] = state.global_step
values["Validation Loss"] = metrics["eval_loss"]
_ = metrics.pop("total_flos", None)
_ = metrics.pop("epoch", None)
for k, v in metrics.items():
if k == "eval_loss":
values["Validation Loss"] = v
else:
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits[1:]])
values[name] = v
self.training_tracker.write_line(values)
self.training_tracker.remove_child()
self.prediction_bar = None
def on_train_end(self, args, state, control, **kwargs):
self.training_tracker.update(
state.global_step, comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", force_update=True
)
self.training_tracker = None