diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index da60c8b8f8..e183ff007b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4591,7 +4591,10 @@ class Trainer: >>> model = trainer.free_memory(model) ``` """ - *models, self.optimizer = release_memory(*models, self.optimizer) + # We need to have these references so they can be set to `None` + *models, self.optimizer, self.model, self.deepspeed, self.model_wrapped = release_memory( + *models, self.optimizer, self.model, self.deepspeed, self.model_wrapped + ) return models def propagate_args_to_deepspeed(self, auto_find_batch_size=False):