Model templates (#10072)

Lysandre Debut 2021-02-08 15:07:02 +01:00 committed by GitHub
parent 3b7e612a5e
commit c9df1b1d53
2 changed files with 4 additions and 4 deletions


@@ -161,7 +161,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        self.rsqrt_att_head_size = 1.0 / math.sqrt(self.attention_head_size)
+        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
 
         self.query = tf.keras.layers.Dense(
             units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
@@ -201,8 +201,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
         # attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
         attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
-        dk = tf.cast(self.rsqrt_att_head_size, dtype=attention_scores.dtype)
-        attention_scores = tf.multiply(attention_scores, dk)
+        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
+        attention_scores = tf.divide(attention_scores, dk)
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in TF{{cookiecutter.camelcase_modelname}}Model call() function)
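The TF change is a pure restyle of the attention scaling: multiplying the raw scores by a pre-computed 1/sqrt(d_k) and dividing them by sqrt(d_k) are numerically equivalent, and the template simply switches to the divide form used in modeling_tf_bert. A minimal standalone sketch (not part of this commit; shapes and head size are arbitrary) illustrating the equivalence:

# Minimal sketch, not from the commit: both scalings produce the same result.
import math

import tensorflow as tf

attention_head_size = 64
# Raw attention scores with shape (batch_size, num_heads, seq_len_q, seq_len_k).
attention_scores = tf.random.normal((2, 12, 8, 8))

# Old template: pre-compute the reciprocal, then multiply.
rsqrt_att_head_size = 1.0 / math.sqrt(attention_head_size)
scaled_old = tf.multiply(attention_scores, tf.cast(rsqrt_att_head_size, attention_scores.dtype))

# New template: keep sqrt(d_k), then divide.
sqrt_att_head_size = math.sqrt(attention_head_size)
scaled_new = tf.divide(attention_scores, tf.cast(sqrt_att_head_size, attention_scores.dtype))

# The difference is only floating-point rounding noise.
print(tf.reduce_max(tf.abs(scaled_old - scaled_new)).numpy())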


@@ -593,7 +593,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
         )
 
 
-# Copied from transformers.models.bert.modeling_bert.BertPredictionHead with Bert->{{cookiecutter.camelcase_modelname}}
+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}PredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
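The PyTorch change only corrects the "Copied from" comment so it names the class actually being copied, BertPredictionHeadTransform, rather than BertPredictionHead. For context, the referenced transform in modeling_bert is roughly a dense layer followed by the configured activation and a LayerNorm; the sketch below is a paraphrase from memory (the class name, default sizes, and the activation handling here are illustrative, not copied from the library):

# Rough sketch of what the referenced BertPredictionHeadTransform does.
import torch.nn as nn


class BertPredictionHeadTransformSketch(nn.Module):
    def __init__(self, hidden_size=768, hidden_act=nn.GELU(), layer_norm_eps=1e-12):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # In the real class this is resolved from config.hidden_act.
        self.transform_act_fn = hidden_act
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states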