From 936b3fdeaaf772a7858f118a960bb4e6710f90d2 Mon Sep 17 00:00:00 2001 From: Kamal Raj Date: Mon, 20 Sep 2021 20:41:04 +0530 Subject: [PATCH] Update modeling_tf_deberta.py (#13654) Fixed expand_dims axis --- src/transformers/models/deberta/modeling_tf_deberta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py index d0cd8f6129..e85e8c4670 100644 --- a/src/transformers/models/deberta/modeling_tf_deberta.py +++ b/src/transformers/models/deberta/modeling_tf_deberta.py @@ -663,7 +663,7 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer): if len(shape_list_pos) == 2: relative_pos = tf.expand_dims(tf.expand_dims(relative_pos, 0), 0) elif len(shape_list_pos) == 3: - relative_pos = tf.expand_dims(relative_pos, 0) + relative_pos = tf.expand_dims(relative_pos, 1) # bxhxqxk elif len(shape_list_pos) != 4: raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {len(shape_list_pos)}")