Fix template (#7040)
This commit is contained in:
parent
762cba3bda
commit
b482ad474a
|
@ -119,11 +119,15 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
def get_input_embeddings(self):
|
||||||
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
return self.embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
self.embeddings.vocab_size = value.shape[0]
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models
|
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
|
|
Loading…
Reference in New Issue