Compare commits

1 Commit

Author: ydshieh
SHA1: 6214d614b9
Message: 95b3c381 + dynamic length in static cache
Date: 2024-05-18 21:08:17 +02:00
1 changed file with 14 additions and 0 deletions

@@ -518,6 +518,7 @@ class GemmaSdpaAttention(GemmaAttention):
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        _length: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -571,6 +572,11 @@ class GemmaSdpaAttention(GemmaAttention):
        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
        is_causal = True if causal_mask is None and q_len > 1 else False
        if _length > 0:
            key_states = key_states[:, :, :_length, :]
            value_states = value_states[:, :, :_length, :]
            causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
@@ -616,6 +622,7 @@ class GemmaDecoderLayer(nn.Module):
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        _length: int = 0,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
@@ -650,6 +657,7 @@ class GemmaDecoderLayer(nn.Module):
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            _length=_length,
            **kwargs,
        )
        hidden_states = residual + hidden_states
@@ -837,6 +845,7 @@ class GemmaModel(GemmaPreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        _length: int = 0,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
@@ -915,6 +924,7 @@ class GemmaModel(GemmaPreTrainedModel):
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    _length=_length,
                )
            hidden_states = layer_outputs[0]
@@ -1070,6 +1080,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        _length: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
@@ -1114,6 +1125,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            _length=_length,
        )
        hidden_states = outputs[0]
@@ -1152,6 +1164,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        _length=None,
        **kwargs,
    ):
        past_length = 0
@@ -1218,6 +1231,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "_length": int(cache_position[-1]) + 1,
            }
        )
        return model_inputs