feat(wandb): logging and configuration improvements (#10826)

* feat: ensure unique artifact id

* feat: allow manual init

* fix: simplify reinit logic

* fix: no dropped value + immediate commits

* fix: wandb use in sagemaker

* docs: improve documentation and formatting

* fix: typos

* docs: improve formatting
Author: Boris Dayma, 2021-03-22 09:45:17 -05:00 (committed by GitHub)
parent b230181d41
commit 125ccead71
2 changed files with 29 additions and 57 deletions

@@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metrics
Advanced configuration is possible by setting environment variables:
<table>
<thead>
<tr>
<th style="text-align:left">Environment Variables</th>
<th style="text-align:left">Options</th>
</tr>
</thead>
<tbody>
<tr>
<td style="text-align:left">WANDB_LOG_MODEL</td>
<td style="text-align:left">Log the model as artifact at the end of training (<b>false</b> by default)</td>
</tr>
<tr>
<td style="text-align:left">WANDB_WATCH</td>
<td style="text-align:left">
<ul>
<li><b>gradients</b> (default): Log histograms of the gradients</li>
<li><b>all</b>: Log histograms of gradients and parameters</li>
<li><b>false</b>: No gradient or parameter logging</li>
</ul>
</td>
</tr>
<tr>
<td style="text-align:left">WANDB_PROJECT</td>
<td style="text-align:left">Organize runs by project</td>
</tr>
</tbody>
</table>
| Environment Variable | Value |
|---|---|
| WANDB_LOG_MODEL | Log the model as artifact at the end of training (`false` by default) |
| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
| WANDB_PROJECT | Organize runs by project |
Set run names with `run_name` argument present in scripts or as part of `TrainingArguments`.
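
For reference, a minimal sketch of how these variables and `run_name` might be combined before training (the project and run names below are illustrative, not part of this change):

```python
import os

# Configure the W&B integration through environment variables
# before the Trainer is created.
os.environ["WANDB_PROJECT"] = "my-project"   # group runs under one project
os.environ["WANDB_LOG_MODEL"] = "true"       # upload the model as an artifact at the end of training
os.environ["WANDB_WATCH"] = "all"            # log histograms of gradients and parameters

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    run_name="bert-finetune-baseline",  # becomes the W&B run name
)
```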

@@ -19,7 +19,6 @@ import io
import json
import numbers
import os
import re
import tempfile
from copy import deepcopy
from pathlib import Path
@@ -559,20 +558,12 @@ class WandbCallback(TrainerCallback):
if has_wandb:
import wandb
wandb.ensure_configured()
if wandb.api.api_key is None:
has_wandb = False
logger.warning(
"W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
)
self._wandb = None
else:
self._wandb = wandb
self._wandb = wandb
self._initialized = False
# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
def setup(self, args, state, model, reinit, **kwargs):
def setup(self, args, state, model, **kwargs):
"""
Setup the optional Weights & Biases (`wandb`) integration.
@@ -581,7 +572,8 @@ class WandbCallback(TrainerCallback):
Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log model as artifact at the end of training.
Whether or not to log model as artifact at the end of training. Use along with
`TrainingArguments.load_best_model_at_end` to upload best model.
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
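
The updated docstring above suggests pairing `WANDB_LOG_MODEL` with `TrainingArguments.load_best_model_at_end` so the uploaded artifact is the best checkpoint rather than the last one. A hedged sketch of that pairing (exact `TrainingArguments` fields may vary by `transformers` version):

```python
import os
from transformers import TrainingArguments

os.environ["WANDB_LOG_MODEL"] = "true"  # the callback saves whatever model is loaded at train end

training_args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="epoch",        # evaluations are needed to pick a "best" checkpoint
    load_best_model_at_end=True,        # reload the best checkpoint before the callback uploads it
    metric_for_best_model="eval_loss",
)
```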
@@ -610,13 +602,19 @@
else:
run_name = args.run_name
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
config=combined_dict,
name=run_name,
reinit=reinit,
**init_args,
)
if self._wandb.run is None:
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
name=run_name,
**init_args,
)
# add config parameters (run may have been created manually)
self._wandb.config.update(combined_dict, allow_val_change=True)
# define default x-axis (for latest wandb versions)
if getattr(self._wandb, "define_metric", None):
self._wandb.define_metric("train/global_step")
self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
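
With the `if self._wandb.run is None` guard above, a run created manually before training is reused instead of re-initialized, and only its config is updated. A sketch of that manual-init pattern (project name and tags are placeholders):

```python
import wandb

# Create the run yourself, e.g. to set tags or a custom project/entity.
run = wandb.init(project="my-project", tags=["manual-init"])

# ... build the model, datasets and TrainingArguments as usual ...
# trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
# trainer.train()

# WandbCallback.setup() finds wandb.run already set, so it skips wandb.init()
# and only calls wandb.config.update(...) with the Trainer/model parameters.
```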
@@ -628,23 +626,20 @@
if self._wandb is None:
return
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
self.setup(args, state, model, reinit=hp_search, **kwargs)
if hp_search:
self._wandb.finish()
if not self._initialized:
self.setup(args, state, model, **kwargs)
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
# commit last step
if state.is_world_process_zero:
self._wandb.log({})
if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
metadata = (
{
k: v
@@ -657,7 +652,7 @@
"train/total_floss": state.total_flos,
}
)
artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
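
Naming the artifact after the run id keeps it unique across runs, and it can be fetched later by that same name. A sketch using the public `wandb` API (entity name and run id are placeholders; "huggingface" is the callback's default project):

```python
import wandb

api = wandb.Api()
run_id = "abc123"  # id of the training run (placeholder)

# Download the latest version of the model artifact logged by that run.
artifact = api.artifact(f"my-entity/huggingface/model-{run_id}:latest")
model_dir = artifact.download()  # local directory with the saved model files
```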
@@ -668,10 +663,10 @@
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, model, reinit=False)
self.setup(args, state, model)
if state.is_world_process_zero:
logs = rewrite_logs(logs)
self._wandb.log(logs, step=state.global_step)
self._wandb.log({**logs, "train/global_step": state.global_step})
class CometCallback(TrainerCallback):
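
For context on the `define_metric` and `on_log` changes, the same pattern outside the callback looks roughly like this: `train/global_step` is logged as an ordinary value and declared as the step metric, so repeated or out-of-order steps are not dropped (project and metric names below are illustrative):

```python
import wandb

run = wandb.init(project="my-project")

# Make train/global_step the default x-axis for all metrics.
wandb.define_metric("train/global_step")
wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

for step in range(100):
    loss = 1.0 / (step + 1)  # dummy training loss
    # Log the step as a regular value instead of passing `step=`,
    # mirroring the change in on_log above.
    wandb.log({"train/loss": loss, "train/global_step": step})

run.finish()
```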