From 096f2cf12664bb7da41f89897d3a22966baee9b4 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Fri, 16 Jun 2023 10:55:42 -0400
Subject: [PATCH] Tied weights load (#24310)

* Use tied weight keys

* More

* Fix tied weight missing warning

* Only give info on unexpected keys with different classes

* Deal with empty archs

* Fix tests

* Refine test

---
 src/transformers/modeling_utils.py | 27 ++++-----
 tests/test_modeling_utils.py       | 89 ++++++++++++++++++++++++++++--
 2 files changed, 96 insertions(+), 20 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9a0894e4e2..4129c52702 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1779,10 +1779,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
-                if self._keys_to_ignore_on_load_missing is not None:
+                if self._tied_weights_keys is not None:
                     found = 0
                     for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
+                        matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
@@ -3020,22 +3020,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))

         if is_accelerate_available():
+            model.tie_weights()
             tied_params = find_tied_parameters(model)
         else:
             tied_params = []
-        _missing = []
-        for k in missing_keys:
-            found = False
-            for group in tied_params:
-                if k in group:
-                    found = True
-                    if len(group) > 2:
-                        group.remove(k)
-                    else:
-                        _missing.append(k)
-            if not found:
-                _missing.append(k)
-        missing_keys = _missing
+
+        for group in tied_params:
+            missing_in_group = [k for k in missing_keys if k in group]
+            if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
+                missing_keys = [k for k in missing_keys if k not in missing_in_group]

         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
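The hunk above only drops a "missing" key from the report when some other name in its tied group was actually loaded; a group that is entirely absent from the checkpoint still triggers the missing-keys warning. A minimal standalone sketch of that rule, outside of transformers and with hypothetical key names:

# Standalone sketch of the missing-key rule in the hunk above (not the
# transformers implementation); group and key names are made up.
tied_groups = [["decoder.weight", "base.linear.weight"]]
missing_keys = ["decoder.weight", "decoder.bias"]

for group in tied_groups:
    missing_in_group = [k for k in missing_keys if k in group]
    # Filter only when the group is partially missing: the other tied name was
    # loaded, so the "missing" one is recreated when the weights are re-tied.
    if 0 < len(missing_in_group) < len(group):
        missing_keys = [k for k in missing_keys if k not in missing_in_group]

print(missing_keys)  # ['decoder.bias']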
@@ -3275,7 +3268,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             missing_keys = [elem for elem in missing_keys if "SCB" not in elem]

         if len(unexpected_keys) > 0:
-            logger.warning(
+            archs = [] if model.config.architectures is None else model.config.architectures
+            warner = logger.warn if model.__class__.__name__ in archs else logger.info
+            warner(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 760d6c6d54..3b441ec7e5 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -82,16 +82,31 @@ if is_torch_available():

     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
+        base_model_prefix = "base"
         config_class = PretrainedConfig

         def __init__(self, config):
             super().__init__(config)
-            self.linear = nn.Linear(4, 5)
-            self.linear_2 = nn.Linear(5, 6)
+            self.linear = nn.Linear(5, 5)
+            self.linear_2 = nn.Linear(5, 5)

         def forward(self, x):
             return self.linear_2(self.linear(x))

+    class BaseModelWithTiedWeights(PreTrainedModel):
+        config_class = PretrainedConfig
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.linear = nn.Linear(5, 5)
+            self.linear_2 = nn.Linear(5, 5)
+
+        def forward(self, x):
+            return self.linear_2(self.linear(x))
+
+        def tie_weights(self):
+            self.linear_2.weight = self.linear.weight
+
     class ModelWithHead(PreTrainedModel):
         base_model_prefix = "base"
         config_class = PretrainedConfig
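For readers unfamiliar with how the tied test model above works: in PyTorch, tying is just making two modules point at the same Parameter object, so the shared tensor only needs to be stored once and the sharing only exists after the assignment has run (which is why the patch calls model.tie_weights() before find_tied_parameters). A small sketch in plain torch, independent of transformers; the layer names simply mirror BaseModelWithTiedWeights:

import torch
from torch import nn

# Two independent layers.
linear = nn.Linear(5, 5)
linear_2 = nn.Linear(5, 5)

# Tie them the same way BaseModelWithTiedWeights.tie_weights does: both
# attribute names now resolve to a single Parameter.
linear_2.weight = linear.weight
assert linear_2.weight is linear.weight

# The shared tensor shows up under both keys in a combined state dict, which
# is why save_pretrained drops the duplicate name and why a checkpoint that
# lacks "linear_2.weight" is still complete once weights are re-tied on load.
state = {**{f"linear.{k}": v for k, v in linear.state_dict().items()},
         **{f"linear_2.{k}": v for k, v in linear_2.state_dict().items()}}
assert state["linear.weight"].data_ptr() == state["linear_2.weight"].data_ptr()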
@@ -103,12 +118,30 @@ if is_torch_available():
             super().__init__(config)
             self.base = BaseModel(config)
             # linear is a common name between Base and Head on purpose.
-            self.linear = nn.Linear(6, 3)
-            self.linear2 = nn.Linear(3, 5)
+            self.linear = nn.Linear(5, 5)
+            self.linear2 = nn.Linear(5, 5)

         def forward(self, x):
             return self.linear2(self.linear(self.base(x)))

+    class ModelWithHeadAndTiedWeights(PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = PretrainedConfig
+
+        def _init_weights(self, module):
+            pass
+
+        def __init__(self, config):
+            super().__init__(config)
+            self.base = BaseModel(config)
+            self.decoder = nn.Linear(5, 5)
+
+        def forward(self, x):
+            return self.decoder(self.base(x))
+
+        def tie_weights(self):
+            self.decoder.weight = self.base.linear.weight
+

 TINY_T5 = "patrickvonplaten/t5-tiny-random"
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
@@ -857,6 +890,54 @@ class ModelUtilsTest(TestCasePlus):
             ):
                 _ = ModelWithHead.from_pretrained(tmp_dir)

+    def test_tied_weights_reload(self):
+        # Base
+        model = BaseModelWithTiedWeights(PretrainedConfig())
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            new_model = BaseModelWithTiedWeights.from_pretrained(tmp_dir)
+            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)
+
+            state_dict = model.state_dict()
+            # Remove tied weight from state_dict -> model should load with no complain of missing keys
+            del state_dict["linear_2.weight"]
+            torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
+            new_model, load_info = BaseModelWithTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
+            self.assertListEqual(load_info["missing_keys"], [])
+            self.assertIs(new_model.linear.weight, new_model.linear_2.weight)
+
+            # With head
+            model.save_pretrained(tmp_dir)
+            new_model, load_info = ModelWithHeadAndTiedWeights.from_pretrained(tmp_dir, output_loading_info=True)
+            self.assertIs(new_model.base.linear.weight, new_model.decoder.weight)
+            # Should only complain about the missing bias
+            self.assertListEqual(load_info["missing_keys"], ["decoder.bias"])
+
+    def test_unexpected_keys_warnings(self):
+        model = ModelWithHead(PretrainedConfig())
+        logger = logging.get_logger("transformers.modeling_utils")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            # Loading the model with a new class, we don't get a warning for unexpected weights, just an info
+            with CaptureLogger(logger) as cl:
+                _, loading_info = BaseModel.from_pretrained(tmp_dir, output_loading_info=True)
+            self.assertNotIn("were not used when initializing ModelWithHead", cl.out)
+            self.assertEqual(
+                set(loading_info["unexpected_keys"]),
+                {"linear.weight", "linear.bias", "linear2.weight", "linear2.bias"},
+            )
+
+            # Loading the model with the same class, we do get a warning for unexpected weights
+            state_dict = model.state_dict()
+            state_dict["added_key"] = state_dict["linear.weight"]
+            torch.save(state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
+            with CaptureLogger(logger) as cl:
+                _, loading_info = ModelWithHead.from_pretrained(tmp_dir, output_loading_info=True)
+            self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out)
+            self.assertEqual(loading_info["unexpected_keys"], ["added_key"])
+
     @require_torch_gpu
     @slow
     def test_pretrained_low_mem_new_config(self):
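test_unexpected_keys_warnings pins down the logging change from the modeling_utils hunk: unexpected checkpoint keys only warrant a warning when the checkpoint was produced by the very class being instantiated; loading into a different class is the ordinary task-transfer case and is downgraded to info. A standalone sketch of that decision with a hypothetical helper (not a transformers function; the patch itself picks between logger.warn and logger.info inline):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")

def log_unexpected_keys(loading_class_name, checkpoint_architectures, unexpected_keys):
    # Hypothetical helper mirroring the rule in the patch: warn only when the
    # class being instantiated matches one of the architectures recorded in
    # the checkpoint's config; otherwise the stray keys are expected.
    archs = [] if checkpoint_architectures is None else checkpoint_architectures
    log_fn = logger.warning if loading_class_name in archs else logger.info
    log_fn("Some weights were not used when initializing %s: %s", loading_class_name, unexpected_keys)

# Same class as the checkpoint -> warning.
log_unexpected_keys("ModelWithHead", ["ModelWithHead"], ["added_key"])
# Different class, or no architectures recorded in the config -> info.
log_unexpected_keys("BaseModel", ["ModelWithHead"], ["linear2.weight"])
log_unexpected_keys("BaseModel", None, ["linear2.weight"])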