Use save_safetensor to disable safe serialization for XLA (#28669)
* Use save_safetensor to disable safe serialization for XLA https://github.com/huggingface/transformers/issues/28438 * Style fixup
This commit is contained in:
parent
3001543b94
commit
e3934198a3
|
@ -2907,13 +2907,19 @@ class Trainer:
|
|||
is_main_process=self.args.should_save,
|
||||
state_dict=model.state_dict(),
|
||||
save_function=xm.save,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||
state_dict = model.state_dict()
|
||||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
|
||||
model.save_pretrained(
|
||||
output_dir,
|
||||
is_main_process=self.args.should_save,
|
||||
save_function=xm.save,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
if self.tokenizer is not None and self.args.should_save:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
|
|
Loading…
Reference in New Issue