parent
04976a32dc
commit
936b3fdeaa
|
@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
|
|||
if len(shape_list_pos) == 2:
|
||||
relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0)
|
||||
elif len(shape_list_pos) == 3:
|
||||
relative_pos = tf.expand_dims(relative_pos, 0)
|
||||
relative_pos = tf.expand_dims(relative_pos, 1)
|
||||
# bxhxqxk
|
||||
elif len(shape_list_pos) != 4:
|
||||
raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")
|
||||
|
|
Loading…
Reference in New Issue