Add specific notebook ProgressCalback (#7793)
This commit is contained in:
parent
0911b6bd86
commit
62b5622e6b
|
@ -142,6 +142,20 @@ try:
|
|||
except (AttributeError, ImportError):
|
||||
_has_sklearn = False
|
||||
|
||||
try:
|
||||
# Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
|
||||
get_ipython = sys.modules["IPython"].get_ipython
|
||||
if "IPKernelApp" not in get_ipython().config:
|
||||
raise ImportError("console")
|
||||
if "VSCODE_PID" in os.environ:
|
||||
raise ImportError("vscode")
|
||||
|
||||
import IPython # noqa: F401
|
||||
|
||||
_in_notebook = True
|
||||
except (ImportError, KeyError):
|
||||
_in_notebook = False
|
||||
|
||||
|
||||
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
|
||||
|
@ -203,6 +217,10 @@ def is_faiss_available():
|
|||
return _faiss_available
|
||||
|
||||
|
||||
def is_in_notebook():
|
||||
return _in_notebook
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _torch_available:
|
||||
|
|
|
@ -34,7 +34,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
|
||||
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
|
||||
from .integrations import (
|
||||
default_hp_search_backend,
|
||||
is_comet_available,
|
||||
|
@ -89,7 +89,12 @@ _use_native_amp = False
|
|||
_use_apex = False
|
||||
|
||||
DEFAULT_CALLBACKS = [DefaultFlowCallback]
|
||||
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
|
||||
|
||||
if is_in_notebook():
|
||||
from .utils.notebook import NotebookProgressCallback
|
||||
|
||||
DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
|
||||
|
||||
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
||||
if version.parse(torch.__version__) < version.parse("1.6"):
|
||||
|
@ -235,7 +240,7 @@ class Trainer:
|
|||
)
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
|
||||
# Deprecated arguments
|
||||
if "tb_writer" in kwargs:
|
||||
|
|
|
@ -0,0 +1,327 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 Hugging Face
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import IPython.display as disp
|
||||
|
||||
from ..trainer_callback import TrainerCallback
|
||||
|
||||
|
||||
def format_time(t):
|
||||
"Format `t` (in seconds) to (h):mm:ss"
|
||||
t = int(t)
|
||||
h, m, s = t // 3600, (t // 60) % 60, t % 60
|
||||
return f"{h}:{m:02d}:{s:02d}" if h != 0 else f"{m:02d}:{s:02d}"
|
||||
|
||||
|
||||
def html_progress_bar(value, total, prefix, label, width=300):
|
||||
"Html code for a progress bar `value`/`total` with `label` on the right, `prefix` on the left."
|
||||
return f"""
|
||||
<div>
|
||||
<style>
|
||||
/* Turns off some styling */
|
||||
progress {{
|
||||
/* gets rid of default border in Firefox and Opera. */
|
||||
border: none;
|
||||
/* Needs to be in here for Safari polyfill so background images work as expected. */
|
||||
background-size: auto;
|
||||
}}
|
||||
</style>
|
||||
{prefix}
|
||||
<progress value='{value}' max='{total}' style='width:{width}px; height:20px; vertical-align: middle;'></progress>
|
||||
{label}
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
def text_to_html_table(items):
|
||||
"Put the texts in `items` in an HTML table."
|
||||
html_code = """<table border="1" class="dataframe">\n"""
|
||||
html_code += """ <thead>\n <tr style="text-align: left;">\n"""
|
||||
for i in items[0]:
|
||||
html_code += f" <th>{i}</th>\n"
|
||||
html_code += " </tr>\n </thead>\n <tbody>\n"
|
||||
for line in items[1:]:
|
||||
html_code += " <tr>\n"
|
||||
for elt in line:
|
||||
elt = f"{elt:.6f}" if isinstance(elt, float) else str(elt)
|
||||
html_code += f" <td>{elt}</td>\n"
|
||||
html_code += " </tr>\n"
|
||||
html_code += " </tbody>\n</table><p>"
|
||||
return html_code
|
||||
|
||||
|
||||
class NotebookProgressBar:
|
||||
"""
|
||||
A progress par for display in a notebook.
|
||||
|
||||
Class attributes (overridden by derived classes)
|
||||
- **warmup** (:obj:`int`) -- The number of iterations to do at the beginning while ignoring
|
||||
:obj:`update_every`.
|
||||
- **update_every** (:obj:`float`) -- Since calling the time takes some time, we only do it
|
||||
every presumed :obj:`update_every` seconds. The progress bar uses the average time passed
|
||||
up until now to guess the next value for which it will call the update.
|
||||
|
||||
Args:
|
||||
total (:obj:`int`):
|
||||
The total number of iterations to reach.
|
||||
prefix (:obj:`str`, `optional`):
|
||||
A prefix to add before the progress bar.
|
||||
leave (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to leave the progress bar once it's completed. You can always call the
|
||||
:meth:`~transformers.utils.notebook.NotebookProgressBar.close` method to make the bar disappear.
|
||||
parent (:class:`~transformers.notebook.NotebookTrainingTracker`, `optional`):
|
||||
A parent object (like :class:`~transformers.utils.notebook.NotebookTrainingTracker`) that spawns progress
|
||||
bars and handle their display. If set, the object passed must have a :obj:`display()` method.
|
||||
width (:obj:`int`, `optional`, defaults to 300):
|
||||
The width (in pixels) that the bar will take.
|
||||
|
||||
Example::
|
||||
|
||||
import time
|
||||
|
||||
pbar = NotebookProgressBar(100)
|
||||
for val in range(100):
|
||||
pbar.update(val)
|
||||
time.sleep(0.07)
|
||||
pbar.update(100)
|
||||
"""
|
||||
|
||||
warmup = 5
|
||||
update_every = 0.2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
total: int,
|
||||
prefix: Optional[str] = None,
|
||||
leave: bool = True,
|
||||
parent: Optional["NotebookTrainingTracker"] = None,
|
||||
width: int = 300,
|
||||
):
|
||||
self.total = total
|
||||
self.prefix = "" if prefix is None else prefix
|
||||
self.leave = leave
|
||||
self.parent = parent
|
||||
self.width = width
|
||||
self.last_value = None
|
||||
self.comment = None
|
||||
self.output = None
|
||||
|
||||
def update(self, value: int, force_update: bool = False, comment: str = None):
|
||||
"""
|
||||
The main method to update the progress bar to :obj:`value`.
|
||||
|
||||
Args:
|
||||
|
||||
value (:obj:`int`):
|
||||
The value to use. Must be between 0 and :obj:`total`.
|
||||
force_update (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force and update of the internal state and display (by default, the bar will wait for
|
||||
:obj:`value` to reach the value it predicted corresponds to a time of more than the :obj:`update_every`
|
||||
attribute since the last update to avoid adding boilerplate).
|
||||
comment (:obj:`str`, `optional`):
|
||||
A comment to add on the left of the progress bar.
|
||||
"""
|
||||
self.value = value
|
||||
if comment is not None:
|
||||
self.comment = comment
|
||||
if self.last_value is None:
|
||||
self.start_time = self.last_time = time.time()
|
||||
self.start_value = self.last_value = value
|
||||
self.elapsed_time = self.predicted_remaining = None
|
||||
self.first_calls = self.warmup
|
||||
self.wait_for = 1
|
||||
self.update_bar(value)
|
||||
elif value <= self.last_value:
|
||||
return
|
||||
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
|
||||
if self.first_calls > 0:
|
||||
self.first_calls -= 1
|
||||
current_time = time.time()
|
||||
self.elapsed_time = current_time - self.start_time
|
||||
self.average_time_per_item = self.elapsed_time / (value - self.start_value)
|
||||
if value >= self.total:
|
||||
value = self.total
|
||||
self.predicted_remaining = None
|
||||
if not self.leave:
|
||||
self.close()
|
||||
else:
|
||||
self.predicted_remaining = self.average_time_per_item * (self.total - value)
|
||||
self.update_bar(value)
|
||||
self.last_value = value
|
||||
self.last_time = current_time
|
||||
self.wait_for = max(int(self.update_every / self.average_time_per_item), 1)
|
||||
|
||||
def update_bar(self, value, comment=None):
|
||||
spaced_value = " " * (len(str(self.total)) - len(str(value))) + str(value)
|
||||
if self.elapsed_time is None:
|
||||
self.label = f"[{spaced_value}/{self.total} : < :"
|
||||
elif self.predicted_remaining is None:
|
||||
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)}"
|
||||
else:
|
||||
self.label = f"[{spaced_value}/{self.total} {format_time(self.elapsed_time)} < {format_time(self.predicted_remaining)}"
|
||||
self.label += f", {1/self.average_time_per_item:.2f} it/s"
|
||||
self.label += "]" if self.comment is None or len(self.comment) == 0 else f", {self.comment}]"
|
||||
self.display()
|
||||
|
||||
def display(self):
|
||||
self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
|
||||
if self.parent is not None:
|
||||
# If this is a child bar, the parent will take care of the display.
|
||||
self.parent.display()
|
||||
return
|
||||
if self.output is None:
|
||||
self.output = disp.display(disp.HTML(self.html_code), display_id=True)
|
||||
else:
|
||||
self.output.update(disp.HTML(self.html_code))
|
||||
|
||||
def close(self):
|
||||
"Closes the progress bar."
|
||||
if self.parent is None and self.output is not None:
|
||||
self.output.update(disp.HTML(""))
|
||||
|
||||
|
||||
class NotebookTrainingTracker(NotebookProgressBar):
|
||||
"""
|
||||
An object tracking the updates of an ongoing training with progress bars and a nice table reporting metrics.
|
||||
|
||||
Args:
|
||||
|
||||
num_steps (:obj:`int`): The number of steps during training.
|
||||
column_names (:obj:`List[str]`, `optional`):
|
||||
The list of column names for the metrics table (will be infered from the first call to
|
||||
:meth:`~transformers.utils.notebook.NotebookTrainingTracker.write_line` if not set).
|
||||
"""
|
||||
|
||||
def __init__(self, num_steps, column_names=None):
|
||||
super().__init__(num_steps)
|
||||
self.inner_table = None if column_names is None else [column_names]
|
||||
self.child_bar = None
|
||||
|
||||
def display(self):
|
||||
self.html_code = html_progress_bar(self.value, self.total, self.prefix, self.label, self.width)
|
||||
if self.inner_table is not None:
|
||||
self.html_code += text_to_html_table(self.inner_table)
|
||||
if self.child_bar is not None:
|
||||
self.html_code += self.child_bar.html_code
|
||||
if self.output is None:
|
||||
self.output = disp.display(disp.HTML(self.html_code), display_id=True)
|
||||
else:
|
||||
self.output.update(disp.HTML(self.html_code))
|
||||
|
||||
def write_line(self, values):
|
||||
"""
|
||||
Write the values in the inner table.
|
||||
|
||||
Args:
|
||||
values (:obj:`Dict[str, float]`): The values to display.
|
||||
"""
|
||||
if self.inner_table is None:
|
||||
self.inner_table = [list(values.keys()), list(values.values())]
|
||||
else:
|
||||
columns = self.inner_table[0]
|
||||
if len(self.inner_table) == 1:
|
||||
# We give a chance to update the column names at the first iteration
|
||||
for key in values.keys():
|
||||
if key not in columns:
|
||||
columns.append(key)
|
||||
self.inner_table[0] = columns
|
||||
self.inner_table.append([values[c] for c in columns])
|
||||
|
||||
def add_child(self, total, prefix=None, width=300):
|
||||
"""
|
||||
Add a child progress bar disaplyed under the table of metrics. The child progress bar is returned (so it can
|
||||
be easily updated).
|
||||
|
||||
Args:
|
||||
total (:obj:`int`): The number of iterations for the child progress bar.
|
||||
prefix (:obj:`str`, `optional`): A prefix to write on the left of the progress bar.
|
||||
width (:obj:`int`, `optional`, defaults to 300): The width (in pixels) of the progress bar.
|
||||
"""
|
||||
self.child_bar = NotebookProgressBar(total, prefix=prefix, parent=self, width=width)
|
||||
return self.child_bar
|
||||
|
||||
def remove_child(self):
|
||||
"""
|
||||
Closes the child progress bar.
|
||||
"""
|
||||
self.child_bar = None
|
||||
self.display()
|
||||
|
||||
|
||||
class NotebookProgressCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation, optimized for
|
||||
Jupyter Notebooks or Google colab.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.training_tracker = None
|
||||
self.prediction_bar = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
|
||||
self.training_loss = 0
|
||||
self.last_log = 0
|
||||
column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
|
||||
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
|
||||
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if self.prediction_bar is None:
|
||||
if self.training_tracker is not None:
|
||||
self.prediction_bar = self.training_tracker.add_child(len(eval_dataloader))
|
||||
else:
|
||||
self.prediction_bar = NotebookProgressBar(len(eval_dataloader))
|
||||
self.prediction_bar.update(1)
|
||||
else:
|
||||
self.prediction_bar.update(self.prediction_bar.value + 1)
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||
if self.training_tracker is not None:
|
||||
values = {"Training Loss": "No log"}
|
||||
for log in reversed(state.log_history):
|
||||
if "loss" in log:
|
||||
values["Training Loss"] = log["loss"]
|
||||
break
|
||||
|
||||
if self.first_column == "Epoch":
|
||||
values["Epoch"] = int(state.epoch)
|
||||
else:
|
||||
values["Step"] = state.global_step
|
||||
values["Validation Loss"] = metrics["eval_loss"]
|
||||
_ = metrics.pop("total_flos", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
for k, v in metrics.items():
|
||||
if k == "eval_loss":
|
||||
values["Validation Loss"] = v
|
||||
else:
|
||||
splits = k.split("_")
|
||||
name = " ".join([part.capitalize() for part in splits[1:]])
|
||||
values[name] = v
|
||||
self.training_tracker.write_line(values)
|
||||
self.training_tracker.remove_child()
|
||||
self.prediction_bar = None
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
self.training_tracker.update(
|
||||
state.global_step, comment=f"Epoch {int(state.epoch)}/{state.num_train_epochs}", force_update=True
|
||||
)
|
||||
self.training_tracker = None
|
Loading…
Reference in New Issue