From 8a8ce1cc72798b4dca41b0a8db904f1cabfb3d55 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Sun, 19 May 2024 13:11:10 +0200
Subject: [PATCH] 0002 - Avoid passing cache around methods (so far only for
 `StaticCache`)

---
 .../models/gemma/modeling_gemma.py | 35 +++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index cb4fe91eb7..bf9a5b6d6b 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -520,6 +520,8 @@ class GemmaSdpaAttention(GemmaAttention):
         cache_position: Optional[torch.LongTensor] = None,
         _length: int = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        past_key_value = getattr(self, "past_key_values", past_key_value)
+
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
@@ -530,6 +532,7 @@ class GemmaSdpaAttention(GemmaAttention):
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
+                # TODO: We need to avoid passing this.
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
@@ -591,6 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
 
         attn_output = self.o_proj(attn_output)
 
+        # TODO: We need to avoid passing this.
         return attn_output, None, past_key_value
 
 
@@ -854,6 +858,14 @@ class GemmaModel(GemmaPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # TODO: The condition needs to be improved.
+        # So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!).
+        past_key_values = getattr(self, "past_key_values", past_key_values)
+
+        # Attach `past_key_values` to the attention layers, so we avoid passing it explicitly.
+        for layer in self.layers:
+            layer.self_attn.past_key_values = past_key_values
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
@@ -882,8 +894,10 @@ class GemmaModel(GemmaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # Avoid passing `past_key_values` if it is already attached to the model.
+        _past_key_values = None if hasattr(self, "past_key_values") else past_key_values
         causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            attention_mask, inputs_embeds, cache_position, _past_key_values, output_attentions
         )
 
         # embed positions
@@ -910,6 +924,7 @@ class GemmaModel(GemmaPreTrainedModel):
                     hidden_states,
                     causal_mask,
                     position_ids,
+                    # TODO: We need to avoid passing this.
                     past_key_values,
                     output_attentions,
                     use_cache,
@@ -920,7 +935,8 @@ class GemmaModel(GemmaPreTrainedModel):
                     hidden_states,
                     attention_mask=causal_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_values,
+                    # TODO: We need to avoid passing this.
+                    past_key_value=_past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
@@ -949,6 +965,7 @@ class GemmaModel(GemmaPreTrainedModel):
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
+            # TODO: We need to avoid passing this.
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
@@ -962,6 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        # TODO: The condition needs to be improved.
+        # So far, let's focus on `self.past_key_values` (if it is set, it means we are doing something special!).
+        past_key_values = getattr(self, "past_key_values", past_key_values)
+
         # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
         # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
         # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
@@ -1151,6 +1172,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
+            # TODO: We need to avoid passing this.
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
@@ -1235,6 +1257,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
         elif use_cache:
             cache_position = cache_position[-input_length:]
 
+        # --------------------------------------------------------------------------------------------------------------
+        # TODO: So far we have to set `self.model.past_key_values` in each call to `prepare_inputs_for_generation`,
+        # as the `past_key_values` object may already have been replaced by a newly created one.
+        # Attach the cache object to the model instance.
+        self.model.past_key_values = past_key_values
+        # Set `past_key_values` to `None` to avoid the overhead of passing large tensors around.
+        past_key_values = None
+        # --------------------------------------------------------------------------------------------------------------
+
         model_inputs.update(
             {
                 "position_ids": position_ids,
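
Note (illustrative, not part of the patch): the sketch below distills the pattern the diff applies -- attach the cache object to the model once, let each module fall back to it via getattr(), and stop threading it through every forward() signature. The names used here (ToyCache, ToyAttention, ToyModel, prepare_inputs) are hypothetical stand-ins, not the transformers API.

# Standalone sketch of the "attach the cache as an attribute" pattern from the patch.
from typing import Optional


class ToyCache:
    """Stand-in for a StaticCache-like object that layers update in place."""

    def __init__(self) -> None:
        self.steps = 0

    def update(self) -> None:
        self.steps += 1


class ToyAttention:
    def forward(self, hidden_states: list, past_key_value: Optional[ToyCache] = None) -> list:
        # Same trick as the patch: prefer a cache attached to the module over the argument.
        past_key_value = getattr(self, "past_key_values", past_key_value)
        if past_key_value is not None:
            past_key_value.update()
        return hidden_states


class ToyModel:
    def __init__(self, num_layers: int = 2) -> None:
        self.layers = [ToyAttention() for _ in range(num_layers)]

    def forward(self, hidden_states: list, past_key_values: Optional[ToyCache] = None) -> list:
        # Mirrors GemmaModel.forward in the patch: pick up the attached cache and
        # propagate it to the layers instead of passing it through the call chain.
        past_key_values = getattr(self, "past_key_values", past_key_values)
        for layer in self.layers:
            layer.past_key_values = past_key_values
            hidden_states = layer.forward(hidden_states)  # no cache argument needed
        return hidden_states


def prepare_inputs(model: ToyModel, cache: ToyCache) -> dict:
    # Mirrors prepare_inputs_for_generation in the patch: attach the cache to the
    # model instance once, then drop it from the per-step inputs.
    model.past_key_values = cache
    return {"past_key_values": None}


if __name__ == "__main__":
    model = ToyModel()
    cache = ToyCache()
    inputs = prepare_inputs(model, cache)
    model.forward(hidden_states=[0.0], past_key_values=inputs["past_key_values"])
    print(cache.steps)  # -> 2: both layers updated the attached cache without receiving it as an argument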