fix: non-atomic checkpoint save (#27820)

Jonathon Belotti 2023-12-08 08:08:54 -05:00 committed by GitHub
parent fe8d1302c7
commit 4c5ed1d0c9
2 changed files with 32 additions and 6 deletions
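
Before this change, all checkpoint artifacts were written directly into their final checkpoint directory, so a crash or I/O error partway through a save could leave a partial checkpoint behind for later resume logic to pick up. The fix writes every artifact into a tmp-{checkpoint_folder} staging directory and renames it into place only once the whole save has succeeded. A minimal standalone sketch of the same stage-then-rename pattern (the helper name and write_files callback are illustrative, not part of the Trainer API):

import os

def save_checkpoint_atomically(run_dir: str, checkpoint_folder: str, write_files) -> str:
    """Write checkpoint contents into a staging directory, then rename into place."""
    output_dir = os.path.join(run_dir, checkpoint_folder)
    staging_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
    os.makedirs(staging_dir, exist_ok=True)
    # If this raises, output_dir never comes into existence, so resume logic
    # cannot mistake a half-written directory for a valid checkpoint.
    write_files(staging_dir)
    # Atomic on POSIX because staging_dir and output_dir share a parent
    # directory, and hence a filesystem.
    os.rename(staging_dir, output_dir)
    return output_dir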

src/transformers/trainer.py

@@ -2332,13 +2332,21 @@ class Trainer:
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
+            logger.warning(
+                f"Checkpoint destination directory {output_dir} already exists and is non-empty."
+                "Saving will proceed but saved results may be invalid."
+            )
+            staging_output_dir = output_dir
+        else:
+            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
+        self.save_model(staging_output_dir, _internal_call=True)
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
+            self._save_optimizer_and_scheduler(staging_output_dir)
             # Save RNG state
-            self._save_rng_state(output_dir)
+            self._save_rng_state(staging_output_dir)
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2358,10 +2366,14 @@ class Trainer:
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(staging_output_dir)
+
+        # Place checkpoint in final location after all saving is finished.
+        if staging_output_dir != output_dir:
+            os.rename(staging_output_dir, output_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
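
A note on the rename step: os.rename is atomic on POSIX when source and destination are on the same filesystem, which holds here because staging_output_dir and output_dir both sit under run_dir. A concurrent reader therefore sees either no checkpoint directory or a complete one, never a half-written one. The staging_output_dir != output_dir guard covers the fallback branch in the first hunk: when the destination already exists and is non-empty, os.rename could not replace it (a directory rename requires an empty or absent target), so saving proceeds in place, non-atomically, with the warning logged.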

tests/trainer/test_trainer.py

@@ -79,7 +79,8 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1310,6 +1311,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
+    def test_save_checkpoints_is_atomic(self):
+        class UnsaveableTokenizer(PreTrainedTokenizerBase):
+            def save_pretrained(self, *args, **kwargs):
+                raise OSError("simulated file write error")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
+            # Attach unsaveable tokenizer to partially fail checkpointing
+            trainer.tokenizer = UnsaveableTokenizer()
+            with self.assertRaises(OSError) as _context:
+                trainer.train()
+            assert get_last_checkpoint(tmpdir) is None
+
     @require_safetensors
     def test_safe_checkpoints(self):
         for save_safetensors in [True, False]:
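
The new test exercises the failure path: the checkpoint save writes the model weights before the tokenizer, so the stubbed save_pretrained raises OSError partway through the save. Because everything was being written into the tmp- staging directory, no completed checkpoint directory is ever renamed into place, and get_last_checkpoint(tmpdir), which scans for directories matching the checkpoint prefix, correctly returns None instead of pointing at a corrupt checkpoint.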