Fix: save checkpoint after each epoch and push checkpoint to the hub (#13872)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
7079a99e76
commit
a6ea244f99
|
@ -769,6 +769,14 @@ def main():
|
|||
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
# ======================== Prediction loop ==============================
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
@ -808,14 +816,6 @@ def main():
|
|||
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
||||
logger.info(desc)
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue