Use save_safetensor to disable safe serialization for XLA (#28669)

* Use save_safetensor to disable safe serialization for XLA

https://github.com/huggingface/transformers/issues/28438

* Style fixup
This commit is contained in:
jeffhataws 2024-01-24 03:57:45 -08:00 committed by Amy Roberts
parent 3001543b94
commit e3934198a3
1 changed files with 7 additions and 1 deletions

View File

@ -2907,13 +2907,19 @@ class Trainer:
is_main_process=self.args.should_save,
state_dict=model.state_dict(),
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
model.save_pretrained(
output_dir,
is_main_process=self.args.should_save,
save_function=xm.save,
safe_serialization=self.args.save_safetensors,
)
if self.tokenizer is not None and self.args.should_save:
self.tokenizer.save_pretrained(output_dir)