diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 850333ca8d..f7e982c1ca 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -614,13 +614,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): def get_output_embeddings(self) -> tf.keras.layers.Layer: """ - Returns the model's output embeddings. + Returns the model's output embeddings Returns: :obj:`tf.keras.layers.Layer`: A torch module mapping hidden states to vocabulary. """ return None # Overwrite for models with output embeddings + def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]: + """ + Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the + embeddings. + + Return: + :obj:`tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model. + """ + return None + + def get_prefix_bias_name(self) -> Union[None, str]: + """ + Get the concatenated prefix name of the bias from the model name to the parent layer. + + Return: + :obj:`str`: The prefix name of the bias. + """ + return None + def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: """ Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`. @@ -662,7 +681,17 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): # TFSharedEmbeddings return embeddings.weight else: - raise ValueError("word embedding is not defined.") + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + embeddings.build([]) + if hasattr(embeddings, "word_embeddings"): + # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings + return embeddings.word_embeddings + elif hasattr(embeddings, "weight"): + # TFSharedEmbeddings + return embeddings.weight + else: + raise ValueError("word embedding is not defined.") def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable: """ @@ -684,28 +713,87 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): :obj:`new_num_tokens` is :obj:`None` """ word_embeddings = self._get_word_embeddings(old_embeddings) + bias_layer = self.get_output_layer_with_bias() + if new_num_tokens is None: return word_embeddings + old_num_tokens, old_embedding_dim = word_embeddings.shape + if old_num_tokens == new_num_tokens: return word_embeddings # initialize new embeddings # todo: initializer range is not always passed in config. 
init_range = getattr(self.config, "initializer_range", 0.02) + name = ( + self.name + + "/" + + self.base_model_prefix + + "/" + + old_embeddings.name + + "/" + + word_embeddings.name.split(":")[0] + ) new_embeddings = self.add_weight( - "weight", + name=name, shape=[new_num_tokens, old_embedding_dim], initializer=get_initializer(init_range), dtype=tf.float32, ) - init_weights = new_embeddings.numpy() + init_weights = tf.make_ndarray(tf.make_tensor_proto(new_embeddings.value())) # Copy token embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :] + init_weights[:num_tokens_to_copy] = word_embeddings.value()[:num_tokens_to_copy, :] new_embeddings.assign(init_weights) + if bias_layer is not None: + if not hasattr(bias_layer, "bias"): + bias_layer.build([]) + + # Second check in order to be sure the attribute has been properly created + if not hasattr(bias_layer, "bias"): + raise ValueError("bias is not defined.") + + # initialize bias + init_bias = np.zeros((new_num_tokens,)) + init_bias[:num_tokens_to_copy] = bias_layer.bias.value()[ + :num_tokens_to_copy + ] # tf.make_ndarray(tf.make_tensor_proto(bias_layer.bias.value()))[:num_tokens_to_copy] + + bias_layer.bias = self.add_weight( + shape=(new_num_tokens,), + initializer="zeros", + trainable=True, + name=self.get_prefix_bias_name() + "/bias", + ) + + bias_layer.bias.assign(init_bias) + + output_embeddings = self.get_output_embeddings() + + if output_embeddings is not None: + if self.get_input_embeddings() != output_embeddings: + if not hasattr(output_embeddings, "decoder"): + output_embeddings.build([]) + + # Second check in order to be sure the attribute has been properly created + if not hasattr(output_embeddings, "decoder"): + raise ValueError("decoder is not defined.") + + # initialize decoder + init_weights = np.zeros((new_num_tokens, old_embedding_dim)) + init_weights[:num_tokens_to_copy] = output_embeddings.decoder.value()[:num_tokens_to_copy, :] + + output_embeddings.decoder = self.add_weight( + shape=(new_num_tokens, old_embedding_dim), + initializer="zeros", + trainable=True, + name=self.get_prefix_bias_name() + "/decoder/weight", + ) + output_embeddings.decoder.assign(init_weights) + return new_embeddings def prune_heads(self, heads_to_prune): diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index 671f196296..222e815b80 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -467,6 +467,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer): self.decoder_bias = self.add_weight( shape=(self.vocab_size,), initializer="zeros", trainable=True, name="decoder/bias" ) + super().build(input_shape) def call(self, hidden_states): @@ -825,6 +826,32 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel): def get_output_embeddings(self): return self.albert.embeddings + def resize_token_embeddings(self, new_num_tokens): + super().resize_token_embeddings(new_num_tokens=new_num_tokens) + + # ALBERT is a special case where there are two bias to update + # even though self.bias is not used anywhere and is here + # just to make the loading weights from a PT model happy + if new_num_tokens is not None: + num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens) + self.predictions.vocab_size = num_tokens_to_copy + init_bias = tf.zeros((new_num_tokens,)) + 
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy] + name = self.name + "/" + self.predictions.name + "/bias" + self.predictions.bias = self.add_weight( + shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name + ) + self.predictions.bias.assign(init_bias) + + init_decoder_bias = tf.zeros((new_num_tokens,)) + init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy] + name = self.name + "/" + self.predictions.name + "/decoder_bias" + self.predictions.decoder_bias = self.add_weight( + shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name + ) + + self.predictions.decoder_bias.assign(init_decoder_bias) + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -933,6 +960,32 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss) def get_output_embeddings(self): return self.albert.embeddings + def resize_token_embeddings(self, new_num_tokens): + super().resize_token_embeddings(new_num_tokens=new_num_tokens) + + # ALBERT is a special case where there are two bias to update + # even though self.bias is not used anywhere and is here + # just to make the loading weights from a PT model happy + if new_num_tokens is not None: + num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens) + self.predictions.vocab_size = num_tokens_to_copy + init_bias = tf.zeros((new_num_tokens,)) + init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy] + name = self.name + "/" + self.predictions.name + "/bias" + self.predictions.bias = self.add_weight( + shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name + ) + self.predictions.bias.assign(init_bias) + + init_decoder_bias = tf.zeros((new_num_tokens,)) + init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy] + name = self.name + "/" + self.predictions.name + "/decoder_bias" + self.predictions.decoder_bias = self.add_weight( + shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name + ) + + self.predictions.decoder_bias.assign(init_decoder_bias) + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 17d2308df9..88af7f8336 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1049,6 +1049,24 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False ) + def resize_token_embeddings(self, new_num_tokens): + super().resize_token_embeddings(new_num_tokens=new_num_tokens) + + # BART is a special case where the bias has two dimensions + # and not named just `bias` + if new_num_tokens is not None: + num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens) + init_bias = tf.zeros((new_num_tokens,)) + init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy] + name = self.name + "/final_logits_bias" + self.final_logits_bias = self.add_weight( + shape=(1, new_num_tokens), + initializer="zeros", + trainable=False, + name=name, + ) + 
self.final_logits_bias.assign(init_bias) + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def call( diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 61d9dad62a..987b1d9dc0 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -893,6 +893,12 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss): def get_output_embeddings(self): return self.bert.embeddings + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -1002,6 +1008,12 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): def get_output_embeddings(self): return self.bert.embeddings + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -1095,6 +1107,12 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): def get_output_embeddings(self): return self.bert.embeddings + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-cased", diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 6d1680bda2..abbb5d0a57 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -629,6 +629,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss): def get_output_embeddings(self): return self.lm_head.input_embeddings + def get_output_layer_with_bias(self): + return self.lm_head + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_head.name + def prepare_inputs_for_generation(self, inputs, past, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past: diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py index 9e887b4a1c..1f5a1c7d8c 100644 --- a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ b/src/transformers/models/distilbert/modeling_tf_distilbert.py @@ -655,6 +655,12 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModel def get_output_embeddings(self): return self.vocab_projector.input_embeddings + def get_output_layer_with_bias(self): + return self.vocab_projector + + def get_prefix_bias_name(self): + return self.name + "/" + self.vocab_projector.name + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/electra/modeling_tf_electra.py 
b/src/transformers/models/electra/modeling_tf_electra.py index c97b4ed373..3a39b03762 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -882,7 +882,7 @@ class TFElectraMaskedLMHead(tf.keras.layers.Layer): super().build(input_shape) - def call(self, hidden_states): + def call(self, hidden_states, training=False): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -914,8 +914,14 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head") def get_output_embeddings(self): + return self.electra.embeddings + + def get_output_layer_with_bias(self): return self.generator_lm_head + def get_prefix_bias_name(self): + return self.name + "/" + self.generator_lm_head.name + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index 71b21780d7..c1711b7f73 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -766,6 +766,12 @@ class TFFlaubertWithLMHeadModel(TFFlaubertPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.input_embeddings + def get_output_layer_with_bias(self): + return self.pred_layer + + def get_prefix_bias_name(self): + return self.name + "/" + self.pred_layer.name + def prepare_inputs_for_generation(self, inputs, **kwargs): mask_token_id = self.config.mask_token_id lang_id = self.config.lang_id diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 57368134ea..38208112bf 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -1320,6 +1320,15 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss) self.funnel = TFFunnelMainLayer(config, name="funnel") self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head") + def get_output_embeddings(self): + return self.funnel.embeddings + + def get_output_layer_with_bias(self): + return self.lm_head + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_head.name + @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 71595b0564..db30435be5 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -2009,6 +2009,12 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel def get_output_embeddings(self): return self.lm_head.decoder + def get_output_layer_with_bias(self): + return self.lm_head + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_head.name + @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git 
a/src/transformers/models/lxmert/modeling_tf_lxmert.py b/src/transformers/models/lxmert/modeling_tf_lxmert.py index bfe159c154..43cd3b5fc4 100644 --- a/src/transformers/models/lxmert/modeling_tf_lxmert.py +++ b/src/transformers/models/lxmert/modeling_tf_lxmert.py @@ -1257,6 +1257,15 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}), } + def get_output_embeddings(self): + return self.lxmert.embeddings + + def get_output_layer_with_bias(self): + return self.cls.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name + @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFLxmertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def call( diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index 3b637cc4e1..2891223ad3 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -702,6 +702,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer): def get_input_embeddings(self): return self.embeddings + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + self.embeddings.vocab_size = value.shape[0] + def _resize_token_embeddings(self, new_num_tokens): raise NotImplementedError @@ -1024,7 +1028,13 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls") def get_output_embeddings(self): - return self.mobilebert.embeddings + return self.predictions.predictions + + def get_output_layer_with_bias(self): + return self.predictions.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFMobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1117,7 +1127,13 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") def get_output_embeddings(self): - return self.mobilebert.embeddings + return self.mlm.predictions + + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index b65c133f09..02f462572d 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -830,6 +830,12 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss): def get_output_embeddings(self): return self.mpnet.embeddings + def get_output_layer_with_bias(self): + return self.lm_head + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_head.name + @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git 
a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 3764efac83..ae5f3dd223 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -810,6 +810,12 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos def get_output_embeddings(self): return self.lm_head.decoder + def get_output_layer_with_bias(self): + return self.lm_head + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_head.name + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index fc20415a2f..c03022e141 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -803,6 +803,12 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel): def get_output_embeddings(self): return self.pred_layer.input_embeddings + def get_output_layer_with_bias(self): + return self.pred_layer + + def get_prefix_bias_name(self): + return self.name + "/" + self.pred_layer.name + def prepare_inputs_for_generation(self, inputs, **kwargs): mask_token_id = self.config.mask_token_id lang_id = self.config.lang_id diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index e678944d6c..7e72df370a 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -1221,6 +1221,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss): def get_output_embeddings(self): return self.lm_loss.input_embeddings + def get_output_layer_with_bias(self): + return self.lm_loss + + def get_prefix_bias_name(self): + return self.name + "/" + self.lm_loss.name + def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs): # Add dummy token at the end (no attention on this one) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 9cd82c00d1..15ac9571ba 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -772,10 +772,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls") - + def get_output_embeddings(self): return self.{{cookiecutter.lowercase_modelname}}.embeddings + def get_output_layer_with_bias(self): + return self.mlm.predictions + + def get_prefix_bias_name(self): + return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name + @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( 
         tokenizer_class=_TOKENIZER_FOR_DOC,
diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py
index 3dec2837bd..354e116671 100644
--- a/tests/test_modeling_tf_albert.py
+++ b/tests/test_modeling_tf_albert.py
@@ -272,6 +272,17 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py
index c31523612f..99c3d03eca 100644
--- a/tests/test_modeling_tf_bart.py
+++ b/tests/test_modeling_tf_bart.py
@@ -126,6 +126,17 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
         # Should be uncommented during patrick TF refactor
         pass
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
 
 @require_tf
 class TFBartHeadTests(unittest.TestCase):
diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py
index a8ca5e3022..a1d2bb747a 100644
--- a/tests/test_modeling_tf_bert.py
+++ b/tests/test_modeling_tf_bert.py
@@ -331,6 +331,25 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
         model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
         self.assertIsNotNone(model)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        list_lm_models = [TFBertForMaskedLM, TFBertForPreTraining, TFBertLMHeadModel]
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+            if model_class in list_lm_models:
+                x = model.get_output_layer_with_bias()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_prefix_bias_name()
+                assert isinstance(name, str)
+            else:
+                x = model.get_output_layer_with_bias()
+                assert x is None
+                name = model.get_prefix_bias_name()
+                assert name is None
+
     def test_custom_load_tf_weights(self):
         model, output_loading_info = TFBertForTokenClassification.from_pretrained(
             "jplu/tiny-tf-bert-random", output_loading_info=True
diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py
index 2dbd14ef9e..7b2f4196c8 100644
--- a/tests/test_modeling_tf_blenderbot.py
+++ b/tests/test_modeling_tf_blenderbot.py
@@ -18,17 +18,17 @@
 import unittest
 
 from tests.test_configuration_common import ConfigTester
 from tests.test_modeling_tf_bart import TFBartModelTester
 from tests.test_modeling_tf_common import TFModelTesterMixin
-from transformers import (
-    BlenderbotConfig,
-    BlenderbotSmallTokenizer,
-    TFAutoModelForSeq2SeqLM,
-    TFBlenderbotForConditionalGeneration,
-    is_tf_available,
-)
+from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
 from transformers.file_utils import cached_property
 from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
 
 
+if is_tf_available():
+    import tensorflow as tf
+
+    from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
+
+
 class TFBlenderbotModelTester(TFBartModelTester):
     config_updates = dict(
         normalize_before=True,
@@ -65,6 +65,17 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
         # Should be uncommented during patrick TF refactor
         pass
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
 
 @is_pt_tf_cross_test
 @require_tokenizers
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 0e06ee757c..5aa1e78e17 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -592,12 +592,26 @@ class TFModelTesterMixin:
 
     def test_model_common_attributes(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        list_lm_models = (
+            list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values())
+            + list(TF_MODEL_FOR_MASKED_LM_MAPPING.values())
+            + list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values())
+        )
 
         for model_class in self.all_model_classes:
             model = model_class(config)
             assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
-            x = model.get_output_embeddings()
-            assert x is None or isinstance(x, tf.keras.layers.Layer)
+
+            if model_class in list_lm_models:
+                x = model.get_output_layer_with_bias()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_prefix_bias_name()
+                assert isinstance(name, str)
+            else:
+                x = model.get_output_layer_with_bias()
+                assert x is None
+                name = model.get_prefix_bias_name()
+                assert name is None
 
     def test_determinism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py
index fe8eb7e4c8..07d2f8ae65 100644
--- a/tests/test_modeling_tf_gpt2.py
+++ b/tests/test_modeling_tf_gpt2.py
@@ -353,6 +353,17 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     def test_gpt2_sequence_classification_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)
diff --git a/tests/test_modeling_tf_lxmert.py b/tests/test_modeling_tf_lxmert.py
index 047b71c324..1c90ec5e18 100644
--- a/tests/test_modeling_tf_lxmert.py
+++ b/tests/test_modeling_tf_lxmert.py
@@ -678,6 +678,25 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
         extended_model = tf.keras.Model(inputs=[input_ids, visual_feats, visual_pos], outputs=[outputs])
         extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        list_lm_models = [TFLxmertForPreTraining]
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+            if model_class in list_lm_models:
+                x = model.get_output_layer_with_bias()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_prefix_bias_name()
+                assert isinstance(name, str)
+            else:
+                x = model.get_output_layer_with_bias()
+                assert x is None
+                name = model.get_prefix_bias_name()
+                assert name is None
+
     @slow
     def test_saved_model_with_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py
index a713023d4f..b4ce498706 100644
--- a/tests/test_modeling_tf_marian.py
+++ b/tests/test_modeling_tf_marian.py
@@ -94,6 +94,17 @@ class TestTFMarianCommon(TFModelTesterMixin, unittest.TestCase):
         extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
         extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
 
 class AbstractMarianIntegrationTest(unittest.TestCase):
     maxDiff = 1000  # show more chars for failing integration tests
diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py
index d631971c43..80a7e91154 100644
--- a/tests/test_modeling_tf_mbart.py
+++ b/tests/test_modeling_tf_mbart.py
@@ -93,6 +93,17 @@ class TestTFMBartCommon(TFModelTesterMixin, unittest.TestCase):
         extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
         extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
 
 @is_pt_tf_cross_test
 @require_sentencepiece
diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py
index e170798b37..939d2a4235 100644
--- a/tests/test_modeling_tf_mobilebert.py
+++ b/tests/test_modeling_tf_mobilebert.py
@@ -283,6 +283,25 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_mobilebert_for_token_classification(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        list_lm_models = [TFMobileBertForMaskedLM, TFMobileBertForPreTraining]
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+
+            if model_class in list_lm_models:
+                x = model.get_output_layer_with_bias()
+                assert isinstance(x, tf.keras.layers.Layer)
+                name = model.get_prefix_bias_name()
+                assert isinstance(name, str)
+            else:
+                x = model.get_output_layer_with_bias()
+                assert x is None
+                name = model.get_prefix_bias_name()
+                assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         # for model_name in TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py
index 4cb71670e0..7eb9e316c3 100644
--- a/tests/test_modeling_tf_openai.py
+++ b/tests/test_modeling_tf_openai.py
@@ -202,6 +202,17 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_openai_gpt_double_head(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py
index fab6c34373..b6e16f75d5 100644
--- a/tests/test_modeling_tf_pegasus.py
+++ b/tests/test_modeling_tf_pegasus.py
@@ -99,6 +99,17 @@ class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
         extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
         extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
 
 @is_pt_tf_cross_test
 @require_sentencepiece
diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py
index 9854e28fd2..64bb41bef1 100644
--- a/tests/test_modeling_tf_t5.py
+++ b/tests/test_modeling_tf_t5.py
@@ -282,6 +282,17 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         model = TFT5Model.from_pretrained("t5-small")
diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py
index f8a8cc4d24..94167bbac5 100644
--- a/tests/test_modeling_tf_transfo_xl.py
+++ b/tests/test_modeling_tf_transfo_xl.py
@@ -163,6 +163,17 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)
 
+    def test_model_common_attributes(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
+            x = model.get_output_layer_with_bias()
+            assert x is None
+            name = model.get_prefix_bias_name()
+            assert name is None
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
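
The way `get_output_layer_with_bias` and `get_prefix_bias_name` feed into the embedding resize is easier to see outside the diff context. Below is a minimal, self-contained sketch of the same pattern, assuming TensorFlow 2.x; `DemoLMHead` and `DemoLMModel` are hypothetical stand-ins rather than library classes, and only the bias part of the resize is reproduced.

```python
# Minimal sketch (not the library implementation) of how the two new hooks
# cooperate when the vocabulary is resized.
import numpy as np
import tensorflow as tf


class DemoLMHead(tf.keras.layers.Layer):
    """Toy LM head that only owns a bias, mirroring the TF MLM heads."""

    def __init__(self, vocab_size, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        return hidden_states + self.bias


class DemoLMModel(tf.keras.Model):
    def __init__(self, vocab_size, hidden_size=8, **kwargs):
        super().__init__(**kwargs)
        # Stand-in for the input embeddings; unused in this sketch.
        self.embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size, name="embeddings")
        self.lm_head = DemoLMHead(vocab_size, name="lm_head")

    # The two hooks introduced by the diff: the layer that owns the bias,
    # and the fully qualified prefix used to name the re-created variable.
    def get_output_layer_with_bias(self):
        return self.lm_head

    def get_prefix_bias_name(self):
        return self.name + "/" + self.lm_head.name

    def resize_bias(self, new_num_tokens):
        bias_layer = self.get_output_layer_with_bias()
        if not hasattr(bias_layer, "bias"):
            bias_layer.build([])  # force variable creation, as the diff does

        old_bias = bias_layer.bias.numpy()
        num_tokens_to_copy = min(old_bias.shape[0], new_num_tokens)

        # Stage the copy in NumPy, then create and fill the new variable.
        init_bias = np.zeros((new_num_tokens,), dtype=old_bias.dtype)
        init_bias[:num_tokens_to_copy] = old_bias[:num_tokens_to_copy]

        bias_layer.bias = self.add_weight(
            shape=(new_num_tokens,),
            initializer="zeros",
            trainable=True,
            name=self.get_prefix_bias_name() + "/bias",
        )
        bias_layer.bias.assign(init_bias)
        return bias_layer.bias


model = DemoLMModel(vocab_size=10, name="demo_lm_model")
new_bias = model.resize_bias(new_num_tokens=12)
print(new_bias.shape)  # (12,)
```

Re-creating the variable under `get_prefix_bias_name() + "/bias"` keeps the new weight's name aligned with the layer hierarchy it replaces, which is presumably what keeps checkpoint weight names stable after a resize.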
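The base implementation in `modeling_tf_utils.py` stages the values to copy in a NumPy buffer and only then pushes them into the freshly created variable with `assign()`, since eager `tf.Tensor` objects such as the result of `tf.zeros` do not support item assignment. A minimal sketch of that splice, with made-up sizes:

```python
import numpy as np
import tensorflow as tf

old_bias = tf.Variable(tf.random.normal((30522,)), name="bias")
new_num_tokens = 30524
num_tokens_to_copy = min(old_bias.shape[0], new_num_tokens)

# Stage the copy in NumPy (eager tensors do not support item assignment),
# then push the result into the new variable.
init_bias = np.zeros((new_num_tokens,), dtype=np.float32)
init_bias[:num_tokens_to_copy] = old_bias.numpy()[:num_tokens_to_copy]

new_bias = tf.Variable(tf.zeros((new_num_tokens,)), name="bias")
new_bias.assign(init_bias)
```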
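From user code, the resize is still driven through `resize_token_embeddings`. A usage sketch, assuming the `bert-base-cased` checkpoint is reachable; the printed prefix is indicative only:

```python
from transformers import BertTokenizer, TFBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = TFBertForMaskedLM.from_pretrained("bert-base-cased")

# Add new tokens, then grow the input embeddings and the LM head bias together.
tokenizer.add_tokens(["<custom_token_1>", "<custom_token_2>"])
model.resize_token_embeddings(len(tokenizer))

print(model.get_output_layer_with_bias().bias.shape)  # (len(tokenizer),)
print(model.get_prefix_bias_name())  # e.g. "tf_bert_for_masked_lm/mlm___cls/predictions"
```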