Fix timeout propagation for world_process_zero things

Zach Mueller 2024-03-07 14:34:59 -05:00
parent 4ed9ae623d
commit ac62a075f0
1 changed file with 14 additions and 0 deletions
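The gist of the change, as a minimal sketch (illustrative only, not part of the commit; `do_slow_hub_upload` is a hypothetical stand-in): the hub-related Trainer methods do their real work only on process zero, and each early-return path and end-of-method now hits a shared barrier, so the other ranks wait there instead of running ahead to the next collective op and timing out while process zero is still uploading.

def push_section_before(trainer):
    # Before the fix: non-zero ranks return immediately and race ahead; if
    # process zero's upload is slow, they can hit the distributed timeout
    # at the next collective operation.
    if not trainer.is_world_process_zero():
        return
    trainer.do_slow_hub_upload()  # hypothetical stand-in for the real push

def push_section_after(trainer):
    # After the fix: non-zero ranks park at the barrier, and process zero
    # releases them by reaching the matching barrier once its work is done.
    if not trainer.is_world_process_zero():
        trainer.wait_for_everyone()
        return
    trainer.do_slow_hub_upload()  # hypothetical stand-in for the real push
    trainer.wait_for_everyone()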

@@ -2524,6 +2524,7 @@ class Trainer:
         if self.args.push_to_hub:
             self._push_from_checkpoint(output_dir)
+            self.wait_for_everyone()

         # Maybe delete some older checkpoints.
         if self.args.should_save:
@@ -3659,6 +3660,7 @@ class Trainer:
         """
         # Only on process zero
         if not self.is_world_process_zero():
+            self.wait_for_everyone()
             return

         if self.args.hub_model_id is None:
@@ -3669,6 +3671,7 @@ class Trainer:
         repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
         self.hub_model_id = repo_url.repo_id
         self.push_in_progress = None
+        self.wait_for_everyone()

     def create_model_card(
         self,
@@ -3708,6 +3711,7 @@ class Trainer:
                 One or several dataset arguments, to be included in the metadata of the model card.
         """
         if not self.is_world_process_zero():
+            self.wait_for_everyone()
             return

         model_card_filepath = os.path.join(self.args.output_dir, "README.md")
@@ -3743,6 +3747,7 @@ class Trainer:
         if is_peft_library:
             unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+        self.wait_for_everyone()

     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
@@ -4145,3 +4150,12 @@ class Trainer:
         ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
         ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
         ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+    def wait_for_everyone(self):
+        """
+        Alias for `accelerator.wait_for_everyone`.
+
+        If `accelerate` is not installed, does nothing.
+        """
+        if is_accelerate_available():
+            self.accelerator.wait_for_everyone()
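
For reference, a self-contained sketch of the barrier the new helper delegates to, assuming `accelerate` is installed (the script is illustrative and not part of the commit):

# Run under multiple processes, e.g. with `accelerate launch`, to observe the
# synchronization; the sleep is an illustrative stand-in for slow work.
import time

from accelerate import Accelerator

accelerator = Accelerator()

if accelerator.is_main_process:
    # Stand-in for main-process-only work such as a slow hub push.
    time.sleep(5)

# Every process blocks here until all processes have reached this line.
accelerator.wait_for_everyone()

print(f"rank {accelerator.process_index} passed the barrier")

Because the helper guards the call with `is_accelerate_available()`, the barrier is a no-op in the single-process, no-`accelerate` case, so callers do not need their own checks.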