incorrectly very short cache length
This commit is contained in:
parent
1133b90a8d
commit
bd4c28b78c
|
@ -593,7 +593,7 @@ class GemmaSdpaAttention(GemmaAttention):
|
|||
# length = cache_position[-1] + 1
|
||||
|
||||
# incorrect results with torch.compile (index stays at the value obtained in the 2nd forward call)
|
||||
# length = self._seen_tokens
|
||||
length = self._seen_tokens
|
||||
# incorrect results without torch.compile (index stays at the value obtained in the 2nd forward call)
|
||||
# length = 1
|
||||
|
||||
|
@ -601,14 +601,14 @@ class GemmaSdpaAttention(GemmaAttention):
|
|||
# length = _length
|
||||
|
||||
# to use the full length of the static cache
|
||||
_key_states = key_states
|
||||
_value_states = value_states
|
||||
_attn_mask = causal_mask if causal_mask is not None else causal_mask
|
||||
# _key_states = key_states
|
||||
# _value_states = value_states
|
||||
# _attn_mask = causal_mask if causal_mask is not None else causal_mask
|
||||
|
||||
# to use the correct length or the very short length
|
||||
# _key_states = key_states[:, :, :length, :]
|
||||
# _value_states = value_states[:, :, :length, :]
|
||||
# _attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
|
||||
_key_states = key_states[:, :, :length, :]
|
||||
_value_states = value_states[:, :, :length, :]
|
||||
_attn_mask = causal_mask[:, :, :, :length] if causal_mask is not None else causal_mask
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
|
|
Loading…
Reference in New Issue