From b482ad474af2d2f63e04f3a405b51a430e1b4033 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 10 Sep 2020 14:45:52 +0200 Subject: [PATCH] Fix template (#7040) --- templates/adding_a_new_model/modeling_tf_xxx.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/templates/adding_a_new_model/modeling_tf_xxx.py b/templates/adding_a_new_model/modeling_tf_xxx.py index 52a2cb3c28..ea25d8fd75 100644 --- a/templates/adding_a_new_model/modeling_tf_xxx.py +++ b/templates/adding_a_new_model/modeling_tf_xxx.py @@ -119,11 +119,15 @@ class TFXxxMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + self.embeddings.vocab_size = value.shape[0] def _prune_heads(self, heads_to_prune): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models + raise NotImplementedError # Not implemented yet in the library for TF 2.0 models def call( self,