Compare commits
4 commits: `main`...`full_lengt`
| Author | SHA1 | Date |
|---|---|---|
| ydshieh | 7f2ebb5fe9 | |
| ydshieh | 68b71c85e1 | |
| ydshieh | 862cde4ce8 | |
| ydshieh | af2e273e2d | |
```diff
@@ -518,6 +518,7 @@ class GemmaSdpaAttention(GemmaAttention):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
```
```diff
@@ -571,6 +572,11 @@ class GemmaSdpaAttention(GemmaAttention):
         # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
         is_causal = True if causal_mask is None and q_len > 1 else False
 
+        # if _length > 0 and isinstance(past_key_value, StaticCache):
+        #     key_states = key_states[:, :, :_length, :]
+        #     value_states = value_states[:, :, :_length, :]
+        #     causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
```
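The block added here is still commented out, but it sketches the intended optimization: with a `StaticCache`, `key_states` and `value_states` are pre-allocated to the maximum cache length, so SDPA also attends over empty (masked) slots. Slicing the keys, values, and causal mask down to the first `_length` valid positions would shrink that computation. A minimal standalone sketch of the idea, assuming `(batch, heads, seq, head_dim)` tensors; the `sliced_sdpa` helper is hypothetical and not part of this diff:

```python
import torch

def sliced_sdpa(query_states, key_states, value_states, causal_mask, _length):
    # Hypothetical helper (not in the diff): attend only over the first
    # `_length` valid slots of a statically allocated KV cache.
    if _length > 0:
        key_states = key_states[:, :, :_length, :]
        value_states = value_states[:, :, :_length, :]
        # The mask's last dimension indexes key positions, so slice it too.
        causal_mask = causal_mask[..., :_length] if causal_mask is not None else None
    return torch.nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=causal_mask
    )

# Toy shapes: batch=1, 2 heads, cache pre-allocated to 16 slots, 5 of them valid.
q = torch.randn(1, 2, 1, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)
print(sliced_sdpa(q, k, v, causal_mask=None, _length=5).shape)  # torch.Size([1, 2, 1, 8])
```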
```diff
@@ -616,6 +622,7 @@ class GemmaDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
```
```diff
@@ -650,6 +657,7 @@ class GemmaDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            _length=_length,
             **kwargs,
         )
         hidden_states = residual + hidden_states
```
```diff
@@ -837,6 +845,7 @@ class GemmaModel(GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
```
```diff
@@ -915,6 +924,7 @@ class GemmaModel(GemmaPreTrainedModel):
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    _length=_length,
                 )
 
             hidden_states = layer_outputs[0]
```
```diff
@@ -1070,6 +1080,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
```
```diff
@@ -1114,6 +1125,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            _length=_length,
         )
 
         hidden_states = outputs[0]
```
```diff
@@ -1152,6 +1164,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         inputs_embeds=None,
         cache_position=None,
         use_cache=True,
+        _length=None,
         **kwargs,
     ):
         past_length = 0
```
```diff
@@ -1218,6 +1231,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
+                "_length": int(cache_position[-1]) + 1,
             }
         )
         return model_inputs
```
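The caller-side value is derived from `cache_position`: the last entry is the index of the cache slot currently being written, so `int(cache_position[-1]) + 1` counts the slots holding real data. A quick sanity check of that arithmetic, with hypothetical values:

```python
import torch

# During decoding, `cache_position` holds the index of the slot being written.
cache_position = torch.tensor([7])  # hypothetical: writing the 8th token

_length = int(cache_position[-1]) + 1  # slots 0..7 are now valid
assert _length == 8
```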