Experimental loading of MLX files (#29511)

* Experimental loading of MLX files

* Update exception message

* Add test

* Style

* Use model from hf-internal-testing
This commit is contained in:
Pedro Cuenca 2024-03-11 19:42:06 +01:00 committed by GitHub
parent 73a27345d4
commit b382a09e28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 1 deletions

View File

@ -3297,9 +3297,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif metadata.get("format") == "flax":
from_flax = True
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
elif metadata.get("format") == "mlx":
# This is an MLX file; we assume its weights are compatible with PyTorch ("pt")
pass
else:
raise ValueError(
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
)
from_pt = not (from_tf | from_flax)

View File

@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@require_safetensors
def test_model_from_pretrained_from_mlx(self):
    """Round-trip the `hf-internal-testing/tiny-random-mistral-mlx` checkpoint.

    The checkpoint is loaded, re-saved with `safe_serialization=True`, and the
    written `model.safetensors` file is checked to carry `format == "pt"` in
    its metadata. The re-saved model is then reloaded and must produce logits
    matching those of the originally loaded model on the same random input.
    """
    from safetensors import safe_open

    mlx_loaded = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
    self.assertIsNotNone(mlx_loaded)

    with tempfile.TemporaryDirectory() as tmp_dir:
        mlx_loaded.save_pretrained(tmp_dir, safe_serialization=True)

        # The re-saved file must be tagged as a PyTorch checkpoint.
        checkpoint_path = os.path.join(tmp_dir, "model.safetensors")
        with safe_open(checkpoint_path, framework="pt") as handle:
            saved_metadata = handle.metadata()
            self.assertEqual(saved_metadata.get("format"), "pt")

        # Reload while the temporary directory still exists.
        reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir)

    token_ids = torch.randint(100, 1000, (1, 10))
    with torch.no_grad():
        reference_out = mlx_loaded(token_ids)
        reloaded_out = reloaded(token_ids)

    logits_match = torch.allclose(reloaded_out["logits"], reference_out["logits"])
    self.assertTrue(logits_match)
@slow
@require_torch