Storing & logging gradient norm in trainer (#27326)
* Report grad_norm during training
* Support getting grad_norm from DeepSpeed
This commit is contained in:
parent
a4851d9477
commit
4f09d0fd88
|
@@ -198,6 +198,7 @@ if is_accelerate_available():
|
|||
from accelerate import __version__ as accelerate_version
|
||||
from accelerate.utils import (
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
GradientAccumulationPlugin,
|
||||
load_fsdp_model,
|
||||
load_fsdp_optimizer,
|
||||
|
@@ -1856,6 +1857,7 @@ class Trainer:
|
|||
self._total_loss_scalar = 0.0
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
model.zero_grad()
|
||||
grad_norm: Optional[float] = None
|
||||
|
||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||
|
||||
|
@@ -1973,19 +1975,27 @@ class Trainer:
|
|||
# deepspeed does its own clipping
|
||||
|
||||
if is_sagemaker_mp_enabled() and args.fp16:
|
||||
self.optimizer.clip_master_grads(args.max_grad_norm)
|
||||
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
|
||||
elif self.use_apex:
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
nn.utils.clip_grad_norm_(
|
||||
_grad_norm = nn.utils.clip_grad_norm_(
|
||||
amp.master_params(self.optimizer),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
else:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
_grad_norm = self.accelerator.clip_grad_norm_(
|
||||
model.parameters(),
|
||||
args.max_grad_norm,
|
||||
)
|
||||
|
||||
if (
|
||||
is_accelerate_available()
|
||||
and self.accelerator.distributed_type == DistributedType.DEEPSPEED
|
||||
):
|
||||
grad_norm = model.get_global_grad_norm()
|
||||
else:
|
||||
grad_norm = _grad_norm.item() if _grad_norm is not None else None
|
||||
|
||||
# Optimizer step
|
||||
self.optimizer.step()
|
||||
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
||||
|
@@ -1999,7 +2009,7 @@ class Trainer:
|
|||
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
|
||||
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
||||
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
|
||||
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
|
||||
else:
|
||||
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
|
||||
|
||||
|
@@ -2019,7 +2029,7 @@ class Trainer:
|
|||
self.control.should_training_stop = True
|
||||
|
||||
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
|
||||
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
|
||||
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
|
||||
|
||||
if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
|
||||
if is_torch_tpu_available():
|
||||
|
@@ -2356,7 +2366,7 @@ class Trainer:
|
|||
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
|
||||
)
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
|
||||
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
|
||||
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
||||
if is_torch_tpu_available():
|
||||
xm.mark_step()
|
||||
|
@@ -2370,6 +2380,8 @@ class Trainer:
|
|||
tr_loss -= tr_loss
|
||||
|
||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||
if grad_norm is not None:
|
||||
logs["grad_norm"] = grad_norm
|
||||
logs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
self._total_loss_scalar += tr_loss_scalar
|
||||
|
|
Loading…
Reference in New Issue