improve templates: add scale_embedding support to the encoder-decoder model templates (#9342)

Patrick von Platen 2020-12-29 16:48:44 +01:00 committed by GitHub
parent 64103fb6be
commit 785e52cd30
3 changed files with 12 additions and 4 deletions

View File

@@ -174,6 +174,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
{% endif -%}
pad_token_id=1,
@@ -226,6 +227,8 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
{% endif -%}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
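For context, a minimal sketch (not part of the diff) of how the new option behaves in a config generated from this template; the FooConfig name and the example values below are hypothetical:

import math

class FooConfig:
    # Sketch of the relevant template fields only.
    def __init__(self, d_model=1024, scale_embedding=False, **kwargs):
        self.d_model = d_model
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

config = FooConfig(scale_embedding=True)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
print(embed_scale)  # 32.0, since sqrt(1024) == 32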

View File

@@ -1893,6 +1893,8 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
self.layerdrop = config.encoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
@@ -1969,7 +1971,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs["inputs_embeds"] is None:
inputs_embeds = self.embed_tokens(inputs["input_ids"])
inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
else:
inputs_embeds = inputs["inputs_embeds"]
@@ -2038,6 +2040,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
self.padding_idx,
name="embed_positions",
)
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.layers = [TF{{cookiecutter.camelcase_modelname}}DecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
@@ -2142,7 +2145,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs["inputs_embeds"] is None:
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
hidden_states = inputs["inputs_embeds"]
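A standalone sketch of the scaling the TF encoder and decoder now apply to the token embeddings; this is only an illustration, assuming TensorFlow is installed, with made-up vocabulary size, d_model, and input ids:

import math
import tensorflow as tf

d_model = 16
scale_embedding = True
embed_scale = math.sqrt(d_model) if scale_embedding else 1.0  # same expression as in the template

embed_tokens = tf.keras.layers.Embedding(input_dim=100, output_dim=d_model)
input_ids = tf.constant([[4, 7, 9]])
inputs_embeds = embed_tokens(input_ids) * embed_scale  # shape (1, 3, d_model), scaled by sqrt(d_model)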

View File

@@ -2093,6 +2093,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
@@ -2167,7 +2168,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
@@ -2236,6 +2237,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
@@ -2337,7 +2339,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
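And the equivalent sketch for the PyTorch encoder/decoder path, again with made-up sizes; it also shows that the default scale_embedding=False keeps the previous behaviour (embed_scale == 1.0):

import math
import torch
from torch import nn

torch.manual_seed(0)
d_model = 16
embed_tokens = nn.Embedding(num_embeddings=100, embedding_dim=d_model, padding_idx=1)
input_ids = torch.tensor([[4, 7, 9]])

for scale_embedding in (False, True):
    embed_scale = math.sqrt(d_model) if scale_embedding else 1.0  # mirrors the template logic
    inputs_embeds = embed_tokens(input_ids) * embed_scale
    print(scale_embedding, inputs_embeds.norm().item())
# False leaves the embeddings unchanged; True multiplies them by sqrt(16) == 4.0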