diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1e863365b5..8836e0be21 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2415,7 +2415,7 @@ class Trainer:
 
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
-                logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
+                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
             logs["learning_rate"] = self._get_learning_rate()
 
             self._total_loss_scalar += tr_loss_scalar
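
For context: the changed line has to cope with grad_norm arriving either as a torch.Tensor (the return value of torch.nn.utils.clip_grad_norm_) or as a plain Python float (as some integrations report it), and it detaches the tensor before calling item() so logging never holds onto the autograd graph. Below is a minimal standalone sketch of that pattern; the helper name normalize_grad_norm is hypothetical and not part of Trainer.

# A minimal sketch of the pattern in the diff above, assuming grad_norm comes
# either from torch.nn.utils.clip_grad_norm_ (a torch.Tensor) or from an
# integration that reports a plain Python float. normalize_grad_norm is a
# hypothetical helper, not an actual Trainer method.
import torch

def normalize_grad_norm(grad_norm):
    """Convert grad_norm to a plain float for logging, whatever its type."""
    if isinstance(grad_norm, torch.Tensor):
        # detach() drops the autograd history before item() pulls the scalar
        # off the device, so logging cannot keep the graph alive.
        return grad_norm.detach().item()
    return grad_norm

# Usage: both call sites yield a plain Python float suitable for a logs dict.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()
tensor_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(normalize_grad_norm(tensor_norm))  # from a Tensor
print(normalize_grad_norm(0.25))         # already a float, passed through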