Update modeling_tf_deberta.py (#13654)

Fixed expand_dims axis
This commit is contained in:
Kamal Raj 2021-09-20 20:41:04 +05:30 committed by GitHub
parent 04976a32dc
commit 936b3fdeaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
if len(shape_list_pos) == 2:
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
elif len(shape_list_pos) == 3:
relative_pos = tf.expand_dims(relative_pos, 0)
relative_pos = tf.expand_dims(relative_pos, 1)
# bxhxqxk
elif len(shape_list_pos) != 4:
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")