0002 - Avoid passing cache around methods (so far only for `StaticCache`)
This commit is contained in:
parent 331c2550dd
commit 8a8ce1cc72
@@ -520,6 +520,8 @@ class GemmaSdpaAttention(GemmaAttention):
         cache_position: Optional[torch.LongTensor] = None,
+        _length: int = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        past_key_value = getattr(self, "past_key_values", past_key_value)

         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
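The `getattr` line added above is the read side of the pattern this commit introduces: a cache that was attached to the module as an attribute takes precedence over whatever is passed in as an argument, while the old calling convention keeps working. A minimal, self-contained sketch of that fallback, using a hypothetical stub class rather than the real `GemmaSdpaAttention`:

```python
# Minimal sketch of the getattr fallback used in the hunk above.
# `_AttnStub` is a hypothetical stand-in, not the actual Gemma attention module.
class _AttnStub:
    def forward(self, hidden_states, past_key_value=None):
        # Prefer a cache attached to the module beforehand; otherwise fall back
        # to the argument, so existing callers are unaffected.
        past_key_value = getattr(self, "past_key_values", past_key_value)
        return past_key_value


stub = _AttnStub()
print(stub.forward(None, past_key_value="from-argument"))  # -> "from-argument"
stub.past_key_values = "attached-to-module"
print(stub.forward(None, past_key_value=None))             # -> "attached-to-module"
```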
@@ -530,6 +532,7 @@ class GemmaSdpaAttention(GemmaAttention):
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
+                # We need to avoid passing this.
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
@@ -591,6 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):

         attn_output = self.o_proj(attn_output)

+        # TODO: We need to avoid passing this.
         return attn_output, None, past_key_value

@@ -854,6 +858,14 @@ class GemmaModel(GemmaPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # TODO: The condition needs to be improved.
+        # So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!)
+        past_key_values = getattr(self, "past_key_values", past_key_values)
+
+        # Set `past_key_values` on the attention layers, so we avoid passing it around.
+        for layer in self.layers:
+            layer.self_attn.past_key_values = past_key_values
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
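This hunk is the attach side of the same pattern: the top-level model resolves the cache once and pins it on every attention layer, so the object no longer has to travel through each layer's forward signature. A hedged sketch of that hand-off with stand-in objects (only the `model.layers[i].self_attn` layout is taken from the diff; everything else is illustrative):

```python
from types import SimpleNamespace


def attach_cache(model, past_key_values):
    # Remember the cache on the model itself ...
    model.past_key_values = past_key_values
    # ... and pin it on every attention layer, whose forward reads it back via
    # `getattr(self, "past_key_values", past_key_value)` as in the first hunk.
    for layer in model.layers:
        layer.self_attn.past_key_values = past_key_values


# Tiny smoke test with stand-in objects (not real nn.Modules).
model = SimpleNamespace(layers=[SimpleNamespace(self_attn=SimpleNamespace()) for _ in range(2)])
attach_cache(model, past_key_values="shared-cache-object")
assert all(layer.self_attn.past_key_values == "shared-cache-object" for layer in model.layers)
```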
@@ -882,8 +894,10 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

+        # Avoid passing `past_key_values` if we already have it embedded.
+        _past_key_values = None if hasattr(self, "past_key_values") else past_key_values
         causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            attention_mask, inputs_embeds, cache_position, _past_key_values, output_attentions
         )

         # embed positions
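The `_past_key_values` guard above only forwards the cache explicitly when it has not been attached to the model. A one-line sketch of the same decision, with a hypothetical helper name:

```python
def cache_to_forward(module, past_key_values):
    # Once the cache lives on the module, downstream calls can receive None and
    # recover it via getattr, so there is no need to keep forwarding the object.
    return None if hasattr(module, "past_key_values") else past_key_values
```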
@@ -910,6 +924,7 @@ class GemmaModel(GemmaPreTrainedModel):
                     hidden_states,
                     causal_mask,
                     position_ids,
+                    # TODO: We need to avoid passing this.
                     past_key_values,
                     output_attentions,
                     use_cache,
@@ -920,7 +935,8 @@ class GemmaModel(GemmaPreTrainedModel):
                     hidden_states,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_values,
+                    # TODO: We need to avoid passing this.
+                    past_key_value=_past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
@@ -949,6 +965,7 @@ class GemmaModel(GemmaPreTrainedModel):
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
+            # TODO: We need to avoid passing this.
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
@@ -962,6 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        # TODO: The condition needs to be improved.
+        # So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!)
+        past_key_values = getattr(self, "past_key_values", past_key_values)
+
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
@@ -1151,6 +1172,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
+            # TODO: avoid passing it
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
@@ -1235,6 +1257,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         elif use_cache:
             cache_position = cache_position[-input_length:]

+        # --------------------------------------------------------------------------------------------------------------
+        # TODO: So far we have to set `self.model.past_key_values` in each call to `prepare_inputs_for_generation`,
+        # as the `past_key_values` object may already be a newly created object.
+        # Attach the cache object to the model instance.
+        self.model.past_key_values = past_key_values
+        # Set `past_key_values` to `None` to avoid the overhead of passing large tensors.
+        past_key_values = None
+        # --------------------------------------------------------------------------------------------------------------
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
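On the generation side, the hunk above performs the hand-off once per decoding step: the cache is pinned on `self.model` and the argument is dropped, which is what lets the `getattr` fallbacks earlier in the diff find it again. A hedged end-to-end sketch of that flow with stand-in classes (none of these names are the actual Hugging Face API; only the pattern matches the diff):

```python
# End-to-end sketch: prepare_inputs pins the cache on the model, passes None down,
# and the inner module recovers the cache via getattr. All classes are stand-ins.
class _InnerModel:
    def forward(self, past_key_values=None):
        past_key_values = getattr(self, "past_key_values", past_key_values)
        return past_key_values


class _CausalLMStub:
    def __init__(self):
        self.model = _InnerModel()

    def prepare_inputs(self, past_key_values=None):
        if past_key_values is not None:
            self.model.past_key_values = past_key_values  # attach once
            past_key_values = None                        # stop forwarding it
        return {"past_key_values": past_key_values}


lm = _CausalLMStub()
inputs = lm.prepare_inputs(past_key_values="static-cache-object")
assert inputs["past_key_values"] is None
assert lm.model.forward(**inputs) == "static-cache-object"
```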