Fixed issue #21039 and added test for low_cpu_mem_usage
This commit is contained in:
Susnato Dhar 2023-01-12 14:33:13 +05:30 committed by GitHub
parent e849e5bb4a
commit b5be744d3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 1 deletions

View File

@ -2629,7 +2629,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step
if low_cpu_mem_usage:
for key in missing_keys:
if key.startswith(prefix):
if key in list(model_state_dict.keys()):
key = key
elif f"{prefix}.key" in list(model_state_dict.keys()):
key = f"{prefix}.key"
elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]

View File

@ -3166,6 +3166,27 @@ class ModelUtilsTest(TestCasePlus):
):
_ = ModelWithHead.from_pretrained(tmp_dir)
@require_torch_gpu
def test_pretrained_low_mem_new_config(self):
    """Load a checkpoint with an overridden config and `low_cpu_mem_usage=True`.

    For each checkpoint id, the config is fetched, its `n_layer`, `n_head`
    and `n_embd` fields are overwritten, and the model is loaded in float16
    with `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=True`. The
    same checkpoint is then loaded with default arguments, and the two
    resulting model classes are asserted to have the same name.
    """
    # Only one checkpoint is exercised: the one described in the original issue.
    for checkpoint in ("gpt2",):
        resized_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=checkpoint)
        resized_config.n_layer = 48
        resized_config.n_head = 25
        resized_config.n_embd = 1600

        load_kwargs = {
            "pretrained_model_name_or_path": checkpoint,
            "config": resized_config,
            "ignore_mismatched_sizes": True,
            "torch_dtype": torch.float16,
            "low_cpu_mem_usage": True,
        }
        low_mem_model = AutoModelForCausalLM.from_pretrained(**load_kwargs)
        reference_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

        # Loading must succeed and yield the same model class as a plain load.
        self.assertEqual(type(low_mem_model).__name__, type(reference_model).__name__)
@require_torch
@is_staging_test