Fix computation of attention_probs when head_mask is provided. (#9853)

* Fix computation of attention_probs when head_mask is provided.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Apply changes to the template

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Funtowicz Morgan 2021-01-28 12:11:52 +01:00 committed by GitHub
parent b936582f71
commit 2ee9f9b69e
4 changed files with 4 additions and 4 deletions
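
For context, here is a minimal standalone sketch (not part of the commit) of why the mask has to hit attention_probs: the einsum in each of these files consumes attention_probs, so multiplying attention_scores at that point no longer affected the output and head_mask was effectively ignored. Shapes, values, and the pruned head index below are illustrative assumptions, not library defaults.

import tensorflow as tf

# Illustrative shapes, matching the einsum pattern "acbe,aecd->abcd" used below:
#   attention_probs: (batch, num_heads, seq_len, seq_len)
#   value_layer:     (batch, seq_len, num_heads, head_dim)
batch, num_heads, seq_len, head_dim = 1, 2, 3, 4
attention_probs = tf.nn.softmax(tf.random.normal((batch, num_heads, seq_len, seq_len)), axis=-1)
value_layer = tf.random.normal((batch, seq_len, num_heads, head_dim))

# Hypothetical head_mask pruning the second head (1.0 keeps a head, 0.0 zeroes it),
# reshaped so it broadcasts over the head axis of attention_probs.
head_mask = tf.reshape(tf.constant([1.0, 0.0]), (1, num_heads, 1, 1))

# The fixed line: mask the probabilities that the einsum actually consumes.
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)

# The pruned head now contributes exactly zero to the attention output.
print(tf.reduce_max(tf.abs(attention_output[:, :, 1, :])).numpy())  # 0.0

# Before this commit, head_mask was multiplied into attention_scores instead,
# but attention_scores is unused past this point, so the mask had no effect.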


@@ -370,7 +370,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)


@@ -253,7 +253,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)


@@ -377,7 +377,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)


@@ -317,7 +317,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
 
         # Mask heads if we want to
        if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)