Experimental loading of MLX files (#29511)
* Experimental loading of MLX files * Update exception message * Add test * Style * Use model from hf-internal-testing
This commit is contained in:
parent
73a27345d4
commit
b382a09e28
|
@ -3297,9 +3297,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
elif metadata.get("format") == "flax":
|
||||
from_flax = True
|
||||
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
|
||||
elif metadata.get("format") == "mlx":
|
||||
# This is a mlx file, we assume weights are compatible with pt
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}"
|
||||
f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
|
||||
)
|
||||
|
||||
from_pt = not (from_tf | from_flax)
|
||||
|
|
|
@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus):
|
|||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
|
||||
|
||||
@require_safetensors
|
||||
def test_model_from_pretrained_from_mlx(self):
|
||||
from safetensors import safe_open
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
with safe_open(os.path.join(tmp_dir, "model.safetensors"), framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
self.assertEqual(metadata.get("format"), "pt")
|
||||
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
|
||||
input_ids = torch.randint(100, 1000, (1, 10))
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids)
|
||||
outputs_from_saved = new_model(input_ids)
|
||||
self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
|
Loading…
Reference in New Issue