fix error: TypeError: Object of type Tensor is not JSON serializable … (#29568)

fix error: TypeError: Object of type Tensor is not JSON serializable trainer

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
yuanzhoulvpi 2024-03-12 01:15:36 +08:00 committed by GitHub
parent e5eb55b88b
commit 47c9570903
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -2415,7 +2415,7 @@ class Trainer:
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar