storing & logging gradient norm in trainer (#27326)

* report grad_norm during training

* support getting grad_norm from deepspeed
Author: Shijie Wu, 2024-02-19 14:07:41 -05:00 (committed by GitHub)
Parent: a4851d9477
Commit: 4f09d0fd88
1 changed file with 18 additions and 6 deletions


@@ -198,6 +198,7 @@ if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate.utils import (
         DistributedDataParallelKwargs,
+        DistributedType,
         GradientAccumulationPlugin,
         load_fsdp_model,
         load_fsdp_optimizer,
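
The newly imported DistributedType is what lets the training loop tell a DeepSpeed launch apart from other backends. A minimal standalone sketch of that check, assuming only a default Accelerate setup (the print is purely illustrative):

from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()
# Under DeepSpeed the engine performs its own gradient clipping, so the Trainer
# (see the hunk below) reads the norm back from the engine instead of relying
# on the return value of clip_grad_norm_.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    print("DeepSpeed backend detected")
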
@@ -1856,6 +1857,7 @@ class Trainer:
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
+        grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
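
grad_norm is declared Optional because it only receives a value on steps that actually perform gradient clipping; until then it stays None, which is why the logging hunk further down guards on it. A tiny standalone illustration of the same pattern (names mirror the diff; nothing here is Trainer internals):

from typing import Optional

grad_norm: Optional[float] = None  # nothing recorded yet

logs = {}
if grad_norm is not None:  # mirrors the guard added in _maybe_log_save_evaluate
    logs["grad_norm"] = grad_norm
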
@@ -1973,19 +1975,27 @@
                         # deepspeed does its own clipping
                         if is_sagemaker_mp_enabled() and args.fp16:
-                            self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
-                            nn.utils.clip_grad_norm_(
+                            _grad_norm = nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
+                        if (
+                            is_accelerate_available()
+                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                        ):
+                            grad_norm = model.get_global_grad_norm()
+                        else:
+                            grad_norm = _grad_norm.item() if _grad_norm is not None else None
                     # Optimizer step
                     self.optimizer.step()
                     optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
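
The non-DeepSpeed branches work because torch.nn.utils.clip_grad_norm_ (and Accelerate's clip_grad_norm_ wrapper) returns the total gradient norm it computed before clipping, so the value can simply be captured and converted with .item(). Under DeepSpeed the engine does its own clipping, hence the read-back via model.get_global_grad_norm() above. A minimal plain-PyTorch sketch of the capture pattern, with a toy model and loss that are only for illustration:

import torch
import torch.nn as nn

model = nn.Linear(16, 4)                        # toy model
loss = model(torch.randn(8, 16)).pow(2).mean()  # toy loss
loss.backward()

# clip_grad_norm_ clips in place and returns the total (pre-clipping) norm as a tensor
_grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_norm = _grad_norm.item() if _grad_norm is not None else None
print("grad_norm =", grad_norm)
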
@@ -1999,7 +2009,7 @@
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2019,7 +2029,7 @@
                 self.control.should_training_stop = True
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_tpu_available():
@@ -2356,7 +2366,7 @@
                     f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
                 )
-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_tpu_available():
                 xm.mark_step()
@@ -2370,6 +2380,8 @@
             tr_loss -= tr_loss
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm
             logs["learning_rate"] = self._get_learning_rate()
             self._total_loss_scalar += tr_loss_scalar
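
Once stored in logs, the new "grad_norm" entry travels with "loss" and "learning_rate" to the console and to any configured reporting integrations. A minimal sketch of consuming it from a custom callback; the GradNormPrinter class is hypothetical, while TrainerCallback.on_log is the standard hook that receives this dict:

from transformers import TrainerCallback

class GradNormPrinter(TrainerCallback):
    # illustrative callback; pass an instance via Trainer(..., callbacks=[GradNormPrinter()])
    def on_log(self, args, state, control, logs=None, **kwargs):
        # "grad_norm" is only present once a clipping step has recorded a value
        if logs is not None and "grad_norm" in logs:
            print(f"step {state.global_step}: grad_norm = {logs['grad_norm']}")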