diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5743634212..95b5d83b68 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2524,6 +2524,7 @@ class Trainer:
 
         if self.args.push_to_hub:
             self._push_from_checkpoint(output_dir)
+        self.wait_for_everyone()
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
@@ -3659,6 +3660,7 @@ class Trainer:
         """
         # Only on process zero
         if not self.is_world_process_zero():
+            self.wait_for_everyone()
             return
 
         if self.args.hub_model_id is None:
@@ -3669,6 +3671,7 @@ class Trainer:
         repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
         self.hub_model_id = repo_url.repo_id
         self.push_in_progress = None
+        self.wait_for_everyone()
 
     def create_model_card(
         self,
@@ -3708,6 +3711,7 @@ class Trainer:
                 One or several dataset arguments, to be included in the metadata of the model card.
         """
         if not self.is_world_process_zero():
+            self.wait_for_everyone()
             return
 
         model_card_filepath = os.path.join(self.args.output_dir, "README.md")
@@ -3743,6 +3747,7 @@ class Trainer:
 
         if is_peft_library:
             unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+        self.wait_for_everyone()
 
     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
@@ -4145,3 +4150,12 @@ class Trainer:
         ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
         ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
         ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+    def wait_for_everyone(self):
+        """
+        Alias for `accelerator.wait_for_everyone`.
+
+        If `accelerate` is not installed, does nothing.
+        """
+        if is_accelerate_available():
+            self.accelerator.wait_for_everyone()
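
The new `wait_for_everyone` helper is a thin wrapper around `accelerate`'s process barrier: process zero does the Hub work while the other ranks block at the same point, so no rank runs ahead of the push. A minimal standalone sketch of that barrier pattern (not part of the patch), assuming `accelerate` is installed and the script is launched with several processes, e.g. via `accelerate launch`:

    from accelerate import Accelerator

    accelerator = Accelerator()

    if accelerator.is_main_process:
        # Only rank 0 does the slow work (pushing a checkpoint / writing the model card),
        # mirroring the `is_world_process_zero()` checks in the diff above.
        ...
    # Every rank blocks here until all ranks arrive, matching `self.wait_for_everyone()`.
    accelerator.wait_for_everyone()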