feat(wandb): logging and configuration improvements (#10826)
* feat: ensure unique artifact id * feat: allow manual init * fix: simplify reinit logic * fix: no dropped value + immediate commits * fix: wandb use in sagemaker * docs: improve documenation and formatting * fix: typos * docs: improve formatting
This commit is contained in:
parent
b230181d41
commit
125ccead71
|
@ -240,34 +240,11 @@ Whenever you use `Trainer` or `TFTrainer` classes, your losses, evaluation metri
|
|||
|
||||
Advanced configuration is possible by setting environment variables:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style="text-align:left">Environment Variables</th>
|
||||
<th style="text-align:left">Options</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style="text-align:left">WANDB_LOG_MODEL</td>
|
||||
<td style="text-align:left">Log the model as artifact at the end of training (<b>false</b> by default)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align:left">WANDB_WATCH</td>
|
||||
<td style="text-align:left">
|
||||
<ul>
|
||||
<li><b>gradients</b> (default): Log histograms of the gradients</li>
|
||||
<li><b>all</b>: Log histograms of gradients and parameters</li>
|
||||
<li><b>false</b>: No gradient or parameter logging</li>
|
||||
</ul>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="text-align:left">WANDB_PROJECT</td>
|
||||
<td style="text-align:left">Organize runs by project</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
| Environment Variable | Value |
|
||||
|---|---|
|
||||
| WANDB_LOG_MODEL | Log the model as artifact (log the model as artifact at the end of training (`false` by default) |
|
||||
| WANDB_WATCH | one of `gradients` (default) to log histograms of gradients, `all` to log histograms of both gradients and parameters, or `false` for no histogram logging |
|
||||
| WANDB_PROJECT | Organize runs by project |
|
||||
|
||||
Set run names with `run_name` argument present in scripts or as part of `TrainingArguments`.
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ import io
|
|||
import json
|
||||
import numbers
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
@ -559,20 +558,12 @@ class WandbCallback(TrainerCallback):
|
|||
if has_wandb:
|
||||
import wandb
|
||||
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key is None:
|
||||
has_wandb = False
|
||||
logger.warning(
|
||||
"W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
|
||||
)
|
||||
self._wandb = None
|
||||
else:
|
||||
self._wandb = wandb
|
||||
self._wandb = wandb
|
||||
self._initialized = False
|
||||
# log outputs
|
||||
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
||||
|
||||
def setup(self, args, state, model, reinit, **kwargs):
|
||||
def setup(self, args, state, model, **kwargs):
|
||||
"""
|
||||
Setup the optional Weights & Biases (`wandb`) integration.
|
||||
|
||||
|
@ -581,7 +572,8 @@ class WandbCallback(TrainerCallback):
|
|||
|
||||
Environment:
|
||||
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to log model as artifact at the end of training.
|
||||
Whether or not to log model as artifact at the end of training. Use along with
|
||||
`TrainingArguments.load_best_model_at_end` to upload best model.
|
||||
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
|
||||
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
|
||||
logging or :obj:`"all"` to log gradients and parameters.
|
||||
|
@ -610,13 +602,19 @@ class WandbCallback(TrainerCallback):
|
|||
else:
|
||||
run_name = args.run_name
|
||||
|
||||
self._wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
||||
config=combined_dict,
|
||||
name=run_name,
|
||||
reinit=reinit,
|
||||
**init_args,
|
||||
)
|
||||
if self._wandb.run is None:
|
||||
self._wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
||||
name=run_name,
|
||||
**init_args,
|
||||
)
|
||||
# add config parameters (run may have been created manually)
|
||||
self._wandb.config.update(combined_dict, allow_val_change=True)
|
||||
|
||||
# define default x-axis (for latest wandb versions)
|
||||
if getattr(self._wandb, "define_metric", None):
|
||||
self._wandb.define_metric("train/global_step")
|
||||
self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)
|
||||
|
||||
# keep track of model topology and gradients, unsupported on TPU
|
||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||
|
@ -628,23 +626,20 @@ class WandbCallback(TrainerCallback):
|
|||
if self._wandb is None:
|
||||
return
|
||||
hp_search = state.is_hyper_param_search
|
||||
if not self._initialized or hp_search:
|
||||
self.setup(args, state, model, reinit=hp_search, **kwargs)
|
||||
if hp_search:
|
||||
self._wandb.finish()
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model, **kwargs)
|
||||
|
||||
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
# commit last step
|
||||
if state.is_world_process_zero:
|
||||
self._wandb.log({})
|
||||
if self._log_model and self._initialized and state.is_world_process_zero:
|
||||
from .trainer import Trainer
|
||||
|
||||
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
fake_trainer.save_model(temp_dir)
|
||||
# use run name and ensure it's a valid Artifact name
|
||||
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
|
||||
metadata = (
|
||||
{
|
||||
k: v
|
||||
|
@ -657,7 +652,7 @@ class WandbCallback(TrainerCallback):
|
|||
"train/total_floss": state.total_flos,
|
||||
}
|
||||
)
|
||||
artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
|
||||
artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
|
||||
for f in Path(temp_dir).glob("*"):
|
||||
if f.is_file():
|
||||
with artifact.new_file(f.name, mode="wb") as fa:
|
||||
|
@ -668,10 +663,10 @@ class WandbCallback(TrainerCallback):
|
|||
if self._wandb is None:
|
||||
return
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model, reinit=False)
|
||||
self.setup(args, state, model)
|
||||
if state.is_world_process_zero:
|
||||
logs = rewrite_logs(logs)
|
||||
self._wandb.log(logs, step=state.global_step)
|
||||
self._wandb.log({**logs, "train/global_step": state.global_step})
|
||||
|
||||
|
||||
class CometCallback(TrainerCallback):
|
||||
|
|
Loading…
Reference in New Issue