Fix: save checkpoint after each epoch and push checkpoint to the hub (#13872)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2021-10-05 13:00:13 +02:00 committed by GitHub
parent 7079a99e76
commit a6ea244f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 8 deletions

View File

@ -769,6 +769,14 @@ def main():
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
# ======================== Prediction loop ==============================
if training_args.do_predict:
logger.info("*** Predict ***")
@ -808,14 +816,6 @@ def main():
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
logger.info(desc)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
if __name__ == "__main__":
main()