Fix mismatching loading in from_pretrained with/without accelerate (#28414)
* fix mismatching behavior in from_pretrained with/without accelerate
* meaningful refactor
* remove added space
* add test
* fix model on the hub
* comment
* use tiny model
* style
parent 002566f398
commit 66db33ddc8
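For illustration, a repro sketch of the mismatch this commit fixes, mirroring the test added below (the model id comes from the diff; the comments describe the behavior before and after the fix):

# Sketch, not part of the commit: with a checkpoint whose weights are stored
# non-contiguously, the two loading paths used to disagree on parameter layout.
from transformers import OwlViTForObjectDetection

# Plain path (no accelerate): PyTorch's load_state_dict copies into the
# freshly allocated, contiguous parameter storage.
model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
print(model.owlvit.visual_projection.weight.is_contiguous())  # True

# Accelerate path (device_map): before this fix, the raw checkpoint tensor was
# assigned directly, leaving the parameter non-contiguous.
model = OwlViTForObjectDetection.from_pretrained(
    "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
)
print(model.owlvit.visual_projection.weight.is_contiguous())  # True after this fix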
@@ -756,18 +756,23 @@ def _load_state_dict_into_meta_model(
             else:
                 param = param.to(dtype)
 
-        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
-        if dtype is None:
-            old_param = model
-            splits = param_name.split(".")
-            for split in splits:
-                old_param = getattr(old_param, split)
-                if old_param is None:
-                    break
+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            old_param = getattr(old_param, split)
+            if old_param is None:
+                break
 
-            if old_param is not None:
-                param = param.to(old_param.dtype)
+        if old_param is not None:
+            if dtype is None:
+                param = param.to(old_param.dtype)
+
+            if old_param.is_contiguous():
+                param = param.contiguous()
 
         set_module_kwargs["value"] = param
 
         if device_map is None:
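A minimal standalone sketch (plain PyTorch, not transformers internals) of the layout semantics the new branch relies on: `copy_` writes into the destination's existing storage, while direct assignment inherits the source layout, hence the explicit `contiguous()` call.

import torch

# A transposed view is a simple stand-in for a non-contiguous checkpoint tensor.
checkpoint_param = torch.randn(4, 2).t()  # shape (2, 4), non-contiguous
assert not checkpoint_param.is_contiguous()

# load_state_dict path: `param.copy_(input_param)` writes into the existing
# parameter storage, so the (contiguous) layout of the model parameter survives.
model_param = torch.empty(2, 4)  # freshly allocated, hence contiguous
model_param.copy_(checkpoint_param)
assert model_param.is_contiguous()

# Meta-device path: the checkpoint tensor is assigned directly, so the new
# `param = param.contiguous()` branch is what restores a contiguous layout.
assert checkpoint_param.contiguous().is_contiguous()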
@@ -34,6 +34,7 @@ from requests.exceptions import HTTPError
 from transformers import (
     AutoConfig,
     AutoModel,
+    OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
     logging,
@@ -835,6 +836,23 @@ class ModelUtilsTest(TestCasePlus):
         outputs2 = new_model_with_offload(inputs)
         self.assertTrue(torch.allclose(outputs1[0].cpu(), outputs2[0].cpu()))
 
+    @slow
+    @require_torch
+    def test_from_pretrained_non_contiguous_checkpoint(self):
+        # See: https://github.com/huggingface/transformers/pull/28414
+        # Tiny models on the Hub have contiguous weights, contrarily to google/owlvit
+        model = OwlViTForObjectDetection.from_pretrained("fxmarty/owlvit-tiny-non-contiguous-weight")
+        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
+
+        model = OwlViTForObjectDetection.from_pretrained(
+            "fxmarty/owlvit-tiny-non-contiguous-weight", device_map="auto"
+        )
+        self.assertTrue(model.owlvit.visual_projection.weight.is_contiguous())
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
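For reference, a sketch of one way a checkpoint can end up with non-contiguous weights, as the fxmarty/owlvit-tiny-non-contiguous-weight fixture does (the file name below is hypothetical):

import torch

# The pickle-based torch.save format records tensor strides, so a transposed
# view round-trips as a non-contiguous tensor.
weight = torch.randn(8, 4).t()  # non-contiguous view, shape (4, 8)
torch.save({"visual_projection.weight": weight}, "pytorch_model.bin")

reloaded = torch.load("pytorch_model.bin")
assert not reloaded["visual_projection.weight"].is_contiguous()

This also motivates the test's `safe_serialization=True` save: safetensors rejects non-contiguous tensors, so that save only succeeds once loading yields contiguous parameters.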