Add support for metadata format MLX (#29335)

Add support for loading safetensors files saved with metadata format mlx.
This commit is contained in:
Alex Ishida 2024-03-07 22:51:59 +09:00 committed by GitHub
parent 923733c22b
commit 45c0651090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -504,7 +504,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax"]:
if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."