Fix num_heads in _upad_input (#26490)
* Fix num_heads in _upad_input

The variable num_key_value_heads had incorrectly been named num_heads, which led to reshaping query_layer with the wrong attention head count. (Using the correct variable self.num_heads instead of num_heads would have been enough, but num_heads is renamed to num_key_value_heads for clarity.)

* Fixed copies using make fix-copies and ran make fixup

---------

Co-authored-by: fseiler <f.seiler@jerocom.de>
parent 67239f7360
commit ca0379b8c8
@@ -692,13 +692,17 @@ class FalconFlashAttention2(FalconAttention):
     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -553,13 +553,17 @@ class LlamaFlashAttention2(LlamaAttention):
 
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
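For context, the two head counts only diverge for models that use grouped-query or multi-query attention (fewer key/value heads than query heads), which is where the old reshape broke. Below is a minimal, self-contained sketch in plain PyTorch, not the transformers implementation, with made-up shapes, showing why query_layer has to be flattened with self.num_heads rather than the head count unpacked from key_layer.shape:

import torch

batch_size, kv_seq_len, head_dim = 2, 5, 64
num_heads = 8             # query heads (self.num_heads in the attention module)
num_key_value_heads = 2   # key/value heads (grouped-query attention)

query_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
key_layer = torch.randn(batch_size, kv_seq_len, num_key_value_heads, head_dim)

# Correct: flatten each tensor with its own head count before indexing out padded tokens.
flat_q = query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)
flat_k = key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
print(flat_q.shape, flat_k.shape)  # torch.Size([10, 8, 64]) torch.Size([10, 2, 64])

# Buggy: before this fix, the head count unpacked from key_layer.shape (2 here) was also
# used to flatten query_layer. The element counts then no longer match, so the reshape
# raises a RuntimeError for any model whose query and key/value head counts differ.
try:
    query_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
except RuntimeError as err:
    print("reshape failed:", err)

For models with standard multi-head attention the two counts coincide, so the original code happened to work there, which is why the bug was easy to miss.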