0002 - Avoid passing cache around methods (so far only for `StaticCache`)

ydshieh 2024-05-19 13:11:10 +02:00
parent 331c2550dd
commit 8a8ce1cc72
1 changed file with 33 additions and 2 deletions
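The change, in a minimal sketch: instead of threading `past_key_value` through every forward signature, the cache is attached to the module as an attribute, and each layer falls back to that attribute via `getattr`. The `ToyAttention`/`ToyModel` classes and the `set_cache` helper below are illustrative stand-ins, not the actual Gemma classes or the transformers API.

```python
import torch.nn as nn


class ToyAttention(nn.Module):
    """Attention stub that reads the cache from an attribute instead of an argument."""

    def forward(self, hidden_states, past_key_value=None):
        # Prefer a cache attached by the caller over one passed through the signature.
        past_key_value = getattr(self, "past_key_values", past_key_value)
        # A real layer would read and update the cached keys/values here.
        return hidden_states


class ToyModel(nn.Module):
    def __init__(self, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyAttention() for _ in range(num_layers)])

    def set_cache(self, cache):
        # Attach the cache object once so forward() no longer has to pass it around.
        self.past_key_values = cache
        for layer in self.layers:
            layer.past_key_values = cache

    def forward(self, hidden_states):
        # The cache is dropped from the call chain entirely.
        for layer in self.layers:
            hidden_states = layer(hidden_states, past_key_value=None)
        return hidden_states
```

The diff below follows the same pattern: `prepare_inputs_for_generation` attaches the cache to `self.model`, `GemmaModel.forward` propagates it to each `self_attn`, and the attention layers pick it up with `getattr(self, "past_key_values", past_key_value)`.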


@@ -520,6 +520,8 @@ class GemmaSdpaAttention(GemmaAttention):
cache_position: Optional[torch.LongTensor] = None,
_length: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
past_key_value = getattr(self, "past_key_values", past_key_value)
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
@@ -530,6 +532,7 @@ class GemmaSdpaAttention(GemmaAttention):
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
# We need to avoid passing this.
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -591,6 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
attn_output = self.o_proj(attn_output)
# TODO: We need to avoid passing this.
return attn_output, None, past_key_value
@@ -854,6 +858,14 @@ class GemmaModel(GemmaPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: The condition needs to be improved.
# So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!)
past_key_values = getattr(self, "past_key_values", past_key_values)
# Set `past_key_values` to the attention layers, so we avoid passing this.
for layer in self.layers:
layer.self_attn.past_key_values = past_key_values
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -882,8 +894,10 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
+ # Avoid passing `past_key_values` if we already have it embedded.
+ _past_key_values = None if hasattr(self, "past_key_values") else past_key_values
causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ attention_mask, inputs_embeds, cache_position, _past_key_values, output_attentions
)
# embed positions
@@ -910,6 +924,7 @@ class GemmaModel(GemmaPreTrainedModel):
hidden_states,
causal_mask,
position_ids,
# TODO: We need to avoid passing this.
past_key_values,
output_attentions,
use_cache,
@@ -920,7 +935,8 @@ class GemmaModel(GemmaPreTrainedModel):
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
- past_key_value=past_key_values,
+ # TODO: We need to avoid passing this.
+ past_key_value=_past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
@@ -949,6 +965,7 @@ class GemmaModel(GemmaPreTrainedModel):
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
# TODO: We need to avoid passing this.
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
@@ -962,6 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
past_key_values: Cache,
output_attentions: bool,
):
# TODO: The condition needs to be improved.
# So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!)
past_key_values = getattr(self, "past_key_values", past_key_values)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
@@ -1151,6 +1172,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
# TODO: avoid passing it
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
@@ -1235,6 +1257,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
elif use_cache:
cache_position = cache_position[-input_length:]
# --------------------------------------------------------------------------------------------------------------
# TODO: So far we have to set `self.model.past_key_values` in each call to `prepare_inputs_for_generation`,
# as the `past_key_values` object passed in may already be a different, newly created object.
# Attach the cache object to the model instance.
self.model.past_key_values = past_key_values
# Set `past_key_values` to `None` to avoid the overhead of passing large tensors.
past_key_values = None
# --------------------------------------------------------------------------------------------------------------
model_inputs.update(
{
"position_ids": position_ids,