Fix computation of attention_probs when head_mask is provided. (#9853)
* Fix computation of attention_probs when head_mask is provided.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Apply changes to the template

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent b936582f71
commit 2ee9f9b69e
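The same one-line fix is applied in TFBertSelfAttention, TFElectraSelfAttention, TFRobertaSelfAttention, and the cookiecutter template: the head_mask must be multiplied into the post-softmax attention_probs rather than the pre-softmax attention_scores, because zeroing the raw scores before the softmax still produces a uniform, non-zero attention distribution for the "masked" head. Below is a minimal standalone sketch of the difference; the tensor names mirror the diff, but the shapes and toy inputs are illustrative assumptions, not the library's code.

import tensorflow as tf

# Toy shapes (batch, num_heads, seq_len, seq_len); purely illustrative.
attention_scores = tf.random.normal((1, 2, 4, 4))
head_mask = tf.reshape(tf.constant([1.0, 0.0]), (1, 2, 1, 1))  # drop head 1

# Before the fix: masking the raw scores. The zeroed scores still go
# through the softmax, which turns them into a uniform distribution,
# so the "masked" head keeps contributing to the output.
probs_from_masked_scores = tf.nn.softmax(tf.multiply(attention_scores, head_mask), axis=-1)

# After the fix: softmax first, then mask the probabilities.
# The masked head's attention rows become exactly zero.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
masked_probs = tf.multiply(attention_probs, head_mask)

print(probs_from_masked_scores[0, 1])  # uniform 0.25 rows: head not actually masked
print(masked_probs[0, 1])              # all zeros: head fully masked

In the patched models, the masked probabilities then feed the tf.einsum that produces attention_output, so a head whose mask entry is 0 no longer contributes to the output.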
@@ -370,7 +370,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
@@ -253,7 +253,7 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
@@ -377,7 +377,7 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
@@ -317,7 +317,7 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer):
 
         # Mask heads if we want to
         if head_mask is not None:
-            attention_scores = tf.multiply(attention_scores, head_mask)
+            attention_probs = tf.multiply(attention_probs, head_mask)
 
         attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value_layer)
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)