full cache length
This commit is contained in:
parent
69f6683b65
commit
1133b90a8d
|
@ -598,17 +598,17 @@ class GemmaSdpaAttention(GemmaAttention):
|
||||||
# length = 1
|
# length = 1
|
||||||
|
|
||||||
# The correct length
|
# The correct length
|
||||||
length = _length
|
# length = _length
|
||||||
|
|
||||||
# to use the full length of the static cache
|
# to use the full length of the static cache
|
||||||
# _key_states = key_states
|
_key_states = key_states
|
||||||
# _value_states = value_states
|
_value_states = value_states
|
||||||
# _attn_mask = causal_mask if causal_mask is not None else causal_mask
|
_attn_mask = causal_mask if causal_mask is not None else causal_mask
|
||||||
|
|
||||||
# to use the correct length or the very short length
|
# to use the correct length or the very short length
|
||||||
_key_states = key_states[:, :, :length, :]
|
# _key_states = key_states[:, :, :length, :]
|
||||||
_value_states = value_states[:, :, :length, :]
|
# _value_states = value_states[:, :, :length, :]
|
||||||
_attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
|
# _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
|
||||||
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query_states,
|
query_states,
|
||||||
|
|
Loading…
Reference in New Issue