fix: non-atomic checkpoint save (#27820)
parent fe8d1302c7
commit 4c5ed1d0c9
src/transformers/trainer.py
@@ -2332,13 +2332,21 @@ class Trainer:
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
+            logger.warning(
+                f"Checkpoint destination directory {output_dir} already exists and is non-empty."
+                "Saving will proceed but saved results may be invalid."
+            )
+            staging_output_dir = output_dir
+        else:
+            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
+        self.save_model(staging_output_dir, _internal_call=True)
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
+            self._save_optimizer_and_scheduler(staging_output_dir)
             # Save RNG state
-            self._save_rng_state(output_dir)
+            self._save_rng_state(staging_output_dir)
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2358,10 +2366,14 @@ class Trainer:
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(staging_output_dir)
 
+        # Place checkpoint in final location after all saving is finished.
+        if staging_output_dir != output_dir:
+            os.rename(staging_output_dir, output_dir)
+
         # Maybe delete some older checkpoints.
         if self.args.should_save:
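The staged save above works because the final os.rename(staging_output_dir, output_dir) only runs after every artifact (model, optimizer, scheduler, RNG state, trainer state) has been written. Below is a minimal, standalone sketch of that stage-then-rename pattern, separate from the Trainer code: atomic_checkpoint_save and write_checkpoint_files are hypothetical names introduced for illustration, and the rename is only atomic when the staging and final directories live on the same filesystem (which holds here, since both sit under run_dir).

import os


def atomic_checkpoint_save(run_dir, checkpoint_folder, write_checkpoint_files):
    """Sketch of a stage-then-rename checkpoint save (illustrative only).

    write_checkpoint_files(path) is a hypothetical callback standing in for the
    individual saves (model weights, optimizer, scheduler, RNG state, ...).
    """
    output_dir = os.path.join(run_dir, checkpoint_folder)
    staging_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")

    os.makedirs(staging_dir, exist_ok=True)
    # If any save fails, only the tmp- directory is left behind, so readers
    # that look for checkpoint-* directories never see a half-written one.
    write_checkpoint_files(staging_dir)

    # Renaming a directory is atomic on POSIX filesystems when source and
    # destination are on the same filesystem, so the finished checkpoint
    # becomes visible all at once.
    os.rename(staging_dir, output_dir)
    return output_dir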
tests/trainer/test_trainer.py
@@ -79,7 +79,8 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1310,6 +1311,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
+    def test_save_checkpoints_is_atomic(self):
+        class UnsaveableTokenizer(PreTrainedTokenizerBase):
+            def save_pretrained(self, *args, **kwargs):
+                raise OSError("simulated file write error")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
+            # Attach unsaveable tokenizer to partially fail checkpointing
+            trainer.tokenizer = UnsaveableTokenizer()
+            with self.assertRaises(OSError) as _context:
+                trainer.train()
+            assert get_last_checkpoint(tmpdir) is None
+
     @require_safetensors
     def test_safe_checkpoints(self):
         for save_safetensors in [True, False]:
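The new test leans on get_last_checkpoint counting only fully renamed checkpoint-<step> directories, so a leftover tmp-checkpoint-* staging directory from the failed save does not register as a usable checkpoint. A small sketch of that assumption (directory names are illustrative; the sketch assumes get_last_checkpoint ignores anything not named checkpoint-<step>):

import os
import tempfile

from transformers.trainer_utils import get_last_checkpoint

with tempfile.TemporaryDirectory() as tmpdir:
    # A failed, staged save leaves only the tmp- directory behind.
    os.makedirs(os.path.join(tmpdir, "tmp-checkpoint-5"))
    assert get_last_checkpoint(tmpdir) is None

    # Once the rename completes, the checkpoint becomes visible.
    os.rename(os.path.join(tmpdir, "tmp-checkpoint-5"), os.path.join(tmpdir, "checkpoint-5"))
    assert get_last_checkpoint(tmpdir) == os.path.join(tmpdir, "checkpoint-5")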