Fix M4T weights tying (#27395)

fix seamless m4t weights tying
This commit is contained in:
Yoach Lacombe 2023-11-14 09:52:11 +00:00 committed by GitHub
parent e107ae364e
commit ee4fb326c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 0 deletions

View File

@ -2785,6 +2785,12 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
self.text_decoder.embed_tokens = value
self.shared = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
def forward(
self,
@ -3077,6 +3083,11 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
def set_input_embeddings(self, value):
self.text_decoder.embed_tokens = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
def forward(
self,
@ -3384,6 +3395,12 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
self.text_decoder.embed_tokens = value
self.shared = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
def forward(
self,
@ -3740,6 +3757,11 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
def set_input_embeddings(self, value):
self.text_decoder.embed_tokens = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
def forward(
self,
@ -4135,6 +4157,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
def forward(