Add support for metadata format MLX (#29335)
Add support for loading safetensors files saved with metadata format mlx.
This commit is contained in:
parent
923733c22b
commit
45c0651090
|
@ -504,7 +504,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
|||
# Check format of the archive
|
||||
with safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata.get("format") not in ["pt", "tf", "flax"]:
|
||||
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||
"you save your model with the `save_pretrained` method."
|
||||
|
|
Loading…
Reference in New Issue