improve templates (#9342)
This commit is contained in:
parent 64103fb6be
commit 785e52cd30
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py

@@ -174,6 +174,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
         init_std=0.02,
         decoder_start_token_id=2,
         classifier_dropout=0.0,
         scale_embedding=False,
+        gradient_checkpointing=False,
         {% endif -%}
         pad_token_id=1,
@@ -226,6 +227,8 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
         self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

         {% endif -%}

         {% if cookiecutter.is_encoder_decoder_model == "False" %}
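The two config hunks above thread new flags through the generated {{cookiecutter.camelcase_modelname}}Config. As a minimal sketch of the resulting pattern: the class name MyModelConfig, the defaults, and the trimmed argument list below are illustrative, not the template's literal output.

from transformers import PretrainedConfig

class MyModelConfig(PretrainedConfig):
    # Illustrative stand-in for the generated {{cookiecutter.camelcase_modelname}}Config.
    model_type = "my-model"

    def __init__(
        self,
        d_model=1024,
        encoder_layers=12,
        scale_embedding=False,         # if True, embeddings are scaled by sqrt(d_model)
        gradient_checkpointing=False,  # recompute activations in backward to save memory
        pad_token_id=1,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.num_hidden_layers = encoder_layers  # mirrors the second hunk above
        self.scale_embedding = scale_embedding
        self.gradient_checkpointing = gradient_checkpointing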
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py

@@ -1893,6 +1893,8 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
         self.layerdrop = config.encoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
         self.embed_tokens = embed_tokens
         self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
@@ -1969,7 +1971,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs["inputs_embeds"] is None:
-            inputs_embeds = self.embed_tokens(inputs["input_ids"])
+            inputs_embeds = self.embed_tokens(inputs["input_ids"]) * self.embed_scale
         else:
             inputs_embeds = inputs["inputs_embeds"]
@@ -2038,6 +2040,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
             self.padding_idx,
             name="embed_positions",
         )
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
         self.layers = [TF{{cookiecutter.camelcase_modelname}}DecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
         self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
@@ -2142,7 +2145,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
         positions = self.embed_positions(input_shape, past_key_values_length)

         if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"])
+            inputs["inputs_embeds"] = self.embed_tokens(inputs["input_ids"]) * self.embed_scale

         hidden_states = inputs["inputs_embeds"]
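The model-side change repeated across the four TF hunks is BART-style embedding scaling: when config.scale_embedding is set, token embeddings are multiplied by sqrt(d_model) before positional embeddings are added; otherwise the multiplier is 1.0 and the extra factor is a no-op. A standalone sketch of the pattern (the dimensions, vocabulary size, and token ids below are made up):

import math
import tensorflow as tf

d_model, vocab_size, scale_embedding = 1024, 50265, True
embed_tokens = tf.keras.layers.Embedding(vocab_size, d_model)

# sqrt(d_model) when scaling is enabled, otherwise a no-op multiplier of 1.0
embed_scale = math.sqrt(d_model) if scale_embedding else 1.0

input_ids = tf.constant([[0, 31414, 232, 2]])
inputs_embeds = embed_tokens(input_ids) * embed_scale  # same line shape as in the hunks above

Scaling keeps token embeddings at a magnitude comparable to the positional embeddings, the convention from the original Transformer that BART and mBART follow.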
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py

@@ -2093,6 +2093,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
@@ -2167,7 +2168,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         embed_pos = self.embed_positions(input_shape)
@@ -2236,6 +2237,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
@@ -2337,7 +2339,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
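The PyTorch hunks mirror the TF ones line for line. The config hunk also wires up gradient_checkpointing; as a rough sketch of how a model of this vintage consumes that flag in its layer loop (the helper below is a hedged illustration built on torch.utils.checkpoint, not code from this commit):

import torch
from torch.utils.checkpoint import checkpoint

def run_layers(layers, hidden_states, gradient_checkpointing, training):
    # With checkpointing on, activations are recomputed in the backward
    # pass instead of being stored, trading compute for memory.
    for layer in layers:
        if gradient_checkpointing and training:
            hidden_states = checkpoint(layer, hidden_states)
        else:
            hidden_states = layer(hidden_states)
    return hidden_states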