From c79bbc3ba54a81dab2eac13d89f264ed64cb2460 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 27 Apr 2022 12:28:42 -0400
Subject: [PATCH] Fix multiple deletions of the same files in save_pretrained
 (#16947)

* Fix multiple deletions of the same files in save_pretrained

* Add is_main_process argument
---
 .../run_image_classification_no_trainer.py    | 14 ++++++++--
 .../language-modeling/run_clm_no_trainer.py   |  8 ++++--
 .../language-modeling/run_mlm_no_trainer.py   |  8 ++++--
 .../multiple-choice/run_swag_no_trainer.py    |  8 ++++--
 .../run_qa_beam_search_no_trainer.py          |  8 ++++--
 .../question-answering/run_qa_no_trainer.py   |  8 ++++--
 .../run_semantic_segmentation_no_trainer.py   | 14 ++++++++--
 .../run_wav2vec2_pretraining_no_trainer.py    |  8 ++++--
 .../run_summarization_no_trainer.py           |  8 ++++--
 .../run_glue_no_trainer.py                    |  8 ++++--
 .../run_ner_no_trainer.py                     |  8 ++++--
 .../translation/run_translation_no_trainer.py |  8 ++++--
 src/transformers/modeling_utils.py            | 28 ++++++++++++++-----
 src/transformers/trainer.py                   |  4 +--
 14 files changed, 105 insertions(+), 35 deletions(-)

diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
index 40c38c6316..97bfe8e9cd 100644
--- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py
+++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py
@@ -444,7 +444,11 @@ def main():
                     if args.push_to_hub and epoch < args.num_train_epochs - 1:
                         accelerator.wait_for_everyone()
                         unwrapped_model = accelerator.unwrap_model(model)
-                        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+                        unwrapped_model.save_pretrained(
+                            args.output_dir,
+                            is_main_process=accelerator.is_main_process,
+                            save_function=accelerator.save,
+                        )
                         if accelerator.is_main_process:
                             feature_extractor.save_pretrained(args.output_dir)
                             repo.push_to_hub(
@@ -490,7 +494,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 feature_extractor.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -506,7 +512,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             feature_extractor.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index c7976ef350..8f14ee67c0 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -580,7 +580,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -596,7 +598,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index 6ff58ddc7f..f05d92396e 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -627,7 +627,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -643,7 +645,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index 0efdc8c2ec..d9f229f7e6 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -588,7 +588,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -604,7 +606,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
index a6b95dfc2d..f7bba206a5 100644
--- a/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -817,7 +817,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -961,7 +963,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/question-answering/run_qa_no_trainer.py b/examples/pytorch/question-answering/run_qa_no_trainer.py
index 1516f97b6e..c9e9b28efe 100755
--- a/examples/pytorch/question-answering/run_qa_no_trainer.py
+++ b/examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -832,7 +832,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -930,7 +932,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
index 2902e23165..1e8d4ff8b6 100644
--- a/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
+++ b/examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
@@ -553,7 +553,11 @@ def main():
                     if args.push_to_hub and epoch < args.num_train_epochs - 1:
                         accelerator.wait_for_everyone()
                         unwrapped_model = accelerator.unwrap_model(model)
-                        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+                        unwrapped_model.save_pretrained(
+                            args.output_dir,
+                            is_main_process=accelerator.is_main_process,
+                            save_function=accelerator.save,
+                        )
                         if accelerator.is_main_process:
                             feature_extractor.save_pretrained(args.output_dir)
                             repo.push_to_hub(
@@ -613,7 +617,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 feature_extractor.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -629,7 +635,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             feature_extractor.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
index 88021a4285..1b43f54110 100755
--- a/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
+++ b/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
@@ -669,7 +669,9 @@ def main():
                 if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
                     accelerator.wait_for_everyone()
                     unwrapped_model = accelerator.unwrap_model(model)
-                    unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+                    unwrapped_model.save_pretrained(
+                        args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+                    )
 
                 if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
                     repo.push_to_hub(
@@ -720,7 +722,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
diff --git a/examples/pytorch/summarization/run_summarization_no_trainer.py b/examples/pytorch/summarization/run_summarization_no_trainer.py
index eb18648b3e..ab680d1d30 100644
--- a/examples/pytorch/summarization/run_summarization_no_trainer.py
+++ b/examples/pytorch/summarization/run_summarization_no_trainer.py
@@ -687,7 +687,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -703,7 +705,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index b0b6c9ce52..86906fe343 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -539,7 +539,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -555,7 +557,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/token-classification/run_ner_no_trainer.py b/examples/pytorch/token-classification/run_ner_no_trainer.py
index 735ec5bd62..0eced00fec 100755
--- a/examples/pytorch/token-classification/run_ner_no_trainer.py
+++ b/examples/pytorch/token-classification/run_ner_no_trainer.py
@@ -691,7 +691,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -707,7 +709,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/examples/pytorch/translation/run_translation_no_trainer.py b/examples/pytorch/translation/run_translation_no_trainer.py
index 9df81d65f6..120aa3e488 100644
--- a/examples/pytorch/translation/run_translation_no_trainer.py
+++ b/examples/pytorch/translation/run_translation_no_trainer.py
@@ -662,7 +662,9 @@ def main():
         if args.push_to_hub and epoch < args.num_train_epochs - 1:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
             if accelerator.is_main_process:
                 tokenizer.save_pretrained(args.output_dir)
                 repo.push_to_hub(
@@ -678,7 +680,9 @@ def main():
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
         unwrapped_model = accelerator.unwrap_model(model)
-        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 36d3758f6b..a0e2b7466f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -20,6 +20,7 @@ import os
 import re
 import shutil
 import tempfile
+import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
@@ -1347,7 +1348,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
-        save_config: bool = True,
+        is_main_process: bool = True,
         state_dict: Optional[dict] = None,
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
@@ -1361,10 +1362,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
-            save_config (`bool`, *optional*, defaults to `True`):
-                Whether or not to save the config of the model. Useful when in distributed training like TPUs and need
-                to call this function on all processes. In this case, set `save_config=True` only on the main process
-                to avoid race conditions.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
             state_dict (nested dictionary of `torch.Tensor`):
                 The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                 save parts of the model or if special precautions need to be taken when recovering the state dictionary
@@ -1397,6 +1398,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
+        if "save_config" in kwargs:
+            warnings.warn(
+                "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
+            )
+            is_main_process = kwargs.pop("save_config")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -1424,7 +1431,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             custom_object_save(self, save_directory, config=self.config)
 
         # Save the config
-        if save_config:
+        if is_main_process:
             model_to_save.config.save_pretrained(save_directory)
 
         # Save the model
@@ -1443,7 +1450,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
-            if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
+            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+            # in distributed settings to avoid race conditions.
+            if (
+                filename.startswith(WEIGHTS_NAME[:-4])
+                and os.path.isfile(full_filename)
+                and filename not in shards.keys()
+                and is_main_process
+            ):
                 os.remove(full_filename)
 
         # Save the model
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index eb0a3ce2ee..eed6863138 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2173,7 +2173,7 @@ class Trainer:
             if isinstance(unwrap_model(self.model), PreTrainedModel):
                 unwrap_model(self.model).save_pretrained(
                     output_dir,
-                    save_config=self.args.should_save,
+                    is_main_process=self.args.should_save,
                     state_dict=self.model.state_dict(),
                     save_function=xm.save,
                 )
@@ -2182,7 +2182,7 @@ class Trainer:
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, save_config=self.args.should_save, save_function=xm.save)
+            self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
         if self.tokenizer is not None and self.args.should_save:
             self.tokenizer.save_pretrained(output_dir)
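
Note (not part of the patch): a minimal sketch of how an Accelerate-based script such as the
no_trainer examples above passes the new `is_main_process` argument, so that every process calls
`save_pretrained` while only the main process writes the config and removes stale weight shards.
The checkpoint name "bert-base-uncased" and the "saved_model" directory are illustrative
assumptions, not values from the patch.

    from accelerate import Accelerator
    from transformers import AutoModel, AutoTokenizer

    accelerator = Accelerator()
    model = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = accelerator.prepare(model)

    # ... training steps would go here ...

    accelerator.wait_for_everyone()  # sync all processes before touching the disk
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        "saved_model",
        is_main_process=accelerator.is_main_process,  # non-main processes skip config save and shard cleanup
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained("saved_model")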
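
Note (not part of the patch): a small sketch of the backward-compatibility path added in
modeling_utils.py above; passing the old `save_config` keyword still works but emits a warning
and is mapped onto `is_main_process`. The checkpoint and directory names are again illustrative
assumptions.

    import warnings

    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model.save_pretrained("saved_model", save_config=True)  # deprecated spelling, still accepted
    print(caught[0].message if caught else "no warning raised")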