diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 9c84a85f1c..eb1576197e 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1755,6 +1755,11 @@ class ProphetNetModel(ProphetNetPreTrainedModel): self.encoder.word_embeddings = self.word_embeddings self.decoder.word_embeddings = self.word_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings) + self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings) + def get_encoder(self): return self.encoder @@ -1876,6 +1881,10 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head) + def get_input_embeddings(self): return self.prophetnet.word_embeddings @@ -2070,7 +2079,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): PROPHETNET_START_DOCSTRING, ) class ProphetNetForCausalLM(ProphetNetPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = [ + "prophetnet.word_embeddings.weight", + "prophetnet.decoder.word_embeddings.weight", + "lm_head.weight", + ] def __init__(self, config: ProphetNetConfig): # set config for CLM @@ -2100,6 +2113,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) + def set_decoder(self, decoder): self.prophetnet.decoder = decoder @@ -2311,7 +2328,15 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): def __init__(self, config: ProphetNetConfig): super().__init__(config) - self.decoder = ProphetNetDecoder(config) + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings()) def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index faa5080b2d..f99cd4549a 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -1779,6 +1779,11 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel): self.encoder.word_embeddings = self.word_embeddings self.decoder.word_embeddings = self.word_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings) + self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings) + def get_encoder(self): return self.encoder @@ -1901,6 +1906,10 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head) + def get_input_embeddings(self): return self.prophetnet.word_embeddings @@ -2098,7 +2107,11 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel): ) # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = [ + "prophetnet.word_embeddings.weight", + "prophetnet.decoder.word_embeddings.weight", + "lm_head.weight", + ] def __init__(self, config: XLMProphetNetConfig): # set config for CLM @@ -2128,6 +2141,10 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) + def set_decoder(self, decoder): self.prophetnet.decoder = decoder @@ -2340,7 +2357,15 @@ class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel): def __init__(self, config: XLMProphetNetConfig): super().__init__(config) - self.decoder = XLMProphetNetDecoder(config) + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.decoder = XLMProphetNetDecoder(config, word_embeddings=self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings()) def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 2649b8f41d..3f55ad9759 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -304,6 +304,25 @@ class Kosmos2ModelTest(ModelTesterMixin, unittest.TestCase): expected_arg_names = ["pixel_values"] self.assertListEqual(arg_names[:1], expected_arg_names) + def test_load_save_without_tied_weights(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.text_config.tie_word_embeddings = False + for model_class in self.all_model_classes: + model = model_class(config) + with tempfile.TemporaryDirectory() as d: + model.save_pretrained(d) + + model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) + # Checking the state dicts are correct + reloaded_state = model_reloaded.state_dict() + for k, v in model.state_dict().items(): + self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded") + torch.testing.assert_close( + v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" + ) + # Checking there was no complain of missing weights + self.assertEqual(infos["missing_keys"], []) + # overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers` def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f96812c36d..fdd48de2fd 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -76,7 +76,7 @@ from transformers.testing_utils import ( from transformers.utils import ( CONFIG_NAME, GENERATION_CONFIG_NAME, - WEIGHTS_NAME, + SAFE_WEIGHTS_NAME, is_accelerate_available, is_flax_available, is_tf_available, @@ -91,6 +91,7 @@ if is_accelerate_available(): if is_torch_available(): import torch + from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file from torch import nn @@ -311,17 +312,20 @@ class ModelTesterMixin: # check that certain keys didn't get saved with the model with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME) - state_dict_saved = torch.load(output_model_file) + output_model_file = os.path.join(tmpdirname, SAFE_WEIGHTS_NAME) + state_dict_saved = safe_load_file(output_model_file) + for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved.keys(), "\n".join(state_dict_saved.keys())) # Test we can load the state dict in the model, necessary for the checkpointing API in Trainer. load_result = model.load_state_dict(state_dict_saved, strict=False) - self.assertTrue( - len(load_result.missing_keys) == 0 - or set(load_result.missing_keys) == set(model._keys_to_ignore_on_save) - ) + keys_to_ignore = set(model._keys_to_ignore_on_save) + + if hasattr(model, "_tied_weights_keys"): + keys_to_ignore.update(set(model._tied_weights_keys)) + + self.assertTrue(len(load_result.missing_keys) == 0 or set(load_result.missing_keys) == keys_to_ignore) self.assertTrue(len(load_result.unexpected_keys) == 0) def test_gradient_checkpointing_backward_compatibility(self):