From ca0379b8c8bb8c827b31b99da6066288450965e5 Mon Sep 17 00:00:00 2001
From: Florian Seiler <41786750+fs4r@users.noreply.github.com>
Date: Mon, 2 Oct 2023 10:10:19 +0200
Subject: [PATCH] Fix num_heads in _upad_input (#26490)

* Fix num_heads in _upad_input

The variable num_key_value_heads was incorrectly named num_heads, which led to
the query_layer being reshaped with the wrong attention head count. (Using the
correct variable self.num_heads instead of num_heads would have been enough,
but num_heads is renamed to num_key_value_heads for clarity.)

* Fixed copies using make fix-copies and ran make fixup

---------

Co-authored-by: fseiler
---
 src/transformers/models/falcon/modeling_falcon.py | 12 ++++++++----
 src/transformers/models/llama/modeling_llama.py   | 12 ++++++++----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 0d0742c89e..5c3de182b2 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -692,13 +692,17 @@ class FalconFlashAttention2(FalconAttention):
     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 0901d38a53..3d9e317750 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -553,13 +553,17 @@ class LlamaFlashAttention2(LlamaAttention):
 
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
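A minimal sketch of the underlying shape issue (not part of the patch; the concrete dimensions below are illustrative assumptions): in grouped-query or multi-query models, key_layer and value_layer carry num_key_value_heads heads while query_layer carries num_heads, so flattening the query with the key/value head count cannot work.

# Illustrative sketch only; the head counts and dimensions are assumed values.
import torch

batch_size, kv_seq_len, head_dim = 2, 5, 64
num_heads = 32              # query heads (self.num_heads)
num_key_value_heads = 8     # key/value heads

query_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
key_layer = torch.randn(batch_size, kv_seq_len, num_key_value_heads, head_dim)

# As in the patched code: each tensor is flattened with its own head count.
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)

# Pre-patch behaviour: reusing the key/value head count for the query fails
# whenever the head counts differ, because the element counts no longer match.
try:
    query_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
except RuntimeError as err:
    print("reshape fails:", err)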