Distributed barrier before loading model (#10685)

This commit is contained in:
Sylvain Gugger 2021-03-15 08:28:15 -04:00 committed by GitHub
parent 339fc51acc
commit e12d6f513e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 0 deletions

View File

@ -1131,6 +1131,12 @@ class Trainer:
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sur the model has been saved by process 0.
if is_torch_tpu_available():
xm.rendezvous("load_best_model_at_end")
elif self.args.local_rank != -1:
dist.barrier()
logger.info(
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
)