full cache length

ydshieh 2024-05-09 23:26:43 +02:00
parent 69f6683b65
commit 1133b90a8d
1 changed file with 7 additions and 7 deletions
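The diff flips a debugging switch between two ways of running scaled_dot_product_attention against a statically allocated KV cache: either pass the full-capacity key/value tensors together with a mask that blanks the unused slots, or first slice keys, values, and mask down to the real sequence length. The sketch below is not taken from the commit; it only illustrates the two paths under assumed shapes (cache capacity 128, a single query token, and a hypothetical fill length `_length`), and the two outputs should match when the mask is consistent.

import torch
import torch.nn.functional as F

# Illustrative shapes only (hypothetical, not from the commit): one query token,
# a static KV cache with capacity 128, of which only `_length` slots are filled.
bsz, num_heads, head_dim = 1, 8, 64
cache_len, _length = 128, 5

query_states = torch.randn(bsz, num_heads, 1, head_dim)
key_states = torch.randn(bsz, num_heads, cache_len, head_dim)
value_states = torch.randn(bsz, num_heads, cache_len, head_dim)

# Additive mask: 0 for filled positions, -inf for the empty tail of the cache.
causal_mask = torch.full((bsz, 1, 1, cache_len), float("-inf"))
causal_mask[..., :_length] = 0.0

# Path enabled by this commit: pass the full static-cache-length tensors and
# let the mask neutralise the unused slots.
out_full = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask
)

# Path being commented out: slice keys, values, and mask down to the positions
# that are actually filled before calling SDPA.
out_sliced = F.scaled_dot_product_attention(
    query_states,
    key_states[:, :, :_length, :],
    value_states[:, :, :_length, :],
    attn_mask=causal_mask[:, :, :, :_length],
)

# With a consistent mask the two paths give the same result (up to float error).
print(torch.allclose(out_full, out_sliced, atol=1e-6))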

@@ -598,17 +598,17 @@ class GemmaSdpaAttention(GemmaAttention):
         # length = 1
         # The correct length
-        length = _length
+        # length = _length
         # to use the full length of the static cache
-        # _key_states = key_states
-        # _value_states = value_states
-        # _attn_mask = causal_mask if causal_mask is not None else causal_mask
+        _key_states = key_states
+        _value_states = value_states
+        _attn_mask = causal_mask if causal_mask is not None else causal_mask
         # to use the correct length or the very short length
-        _key_states = key_states[:, :, :length, :]
-        _value_states = value_states[:, :, :length, :]
-        _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
+        # _key_states = key_states[:, :, :length, :]
+        # _value_states = value_states[:, :, :length, :]
+        # _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,