parent
216dff7549
commit
5e11d72d4d
|
@ -1184,6 +1184,11 @@ class MBartModel(MBartPreTrainedModel):
|
|||
def get_decoder(self):
    """Return the object stored in this model's ``decoder`` attribute."""
    decoder_module = self.decoder
    return decoder_module
|
||||
|
||||
def _tie_weights(self):
    """Hand the encoder and decoder ``embed_tokens`` to ``_tie_or_clone_weights``.

    Each stack's ``embed_tokens`` is passed together with the result of
    ``get_input_embeddings()``. Nothing happens unless
    ``config.tie_word_embeddings`` is truthy.
    """
    if not self.config.tie_word_embeddings:
        return
    # Encoder first, then decoder; the input embedding is looked up anew for each call.
    for stack_name in ("encoder", "decoder"):
        stack = getattr(self, stack_name)
        self._tie_or_clone_weights(stack.embed_tokens, self.get_input_embeddings())
|
||||
|
||||
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
|
|
|
@ -327,6 +327,43 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
model.generate(input_ids, attention_mask=attention_mask)
|
||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
def test_ensure_weights_are_shared(self):
    """Count the distinct embedding storages with and without word-embedding tying."""

    def distinct_embedding_storages(mbart):
        # Storage addresses of the four embedding-related weights the model exposes;
        # weights backed by the same memory collapse into a single set entry.
        return {
            mbart.get_output_embeddings().weight.data_ptr(),
            mbart.get_input_embeddings().weight.data_ptr(),
            mbart.base_model.decoder.embed_tokens.weight.data_ptr(),
            mbart.base_model.encoder.embed_tokens.weight.data_ptr(),
        }

    config, _unused_inputs = self.model_tester.prepare_config_and_inputs()

    # With tying enabled all four weights must share one storage. Getting this wrong
    # is not an issue for torch.load, but it is an issue for safetensors.
    # With tying disabled the test expects exactly two distinct storages
    # (presumably the output projection is the separate one -- not visible here).
    for tie_embeddings, expected_storages in ((True, 1), (False, 2)):
        config.tie_word_embeddings = tie_embeddings
        model = MBartForConditionalGeneration(config)
        self.assertEqual(len(distinct_embedding_storages(model)), expected_storages)
|
||||
|
||||
|
||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||
|
|
Loading…
Reference in New Issue