From cc4a664baaac790aadc4ca9c5d93031893432433 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Wed, 21 Feb 2024 12:19:30 +0100
Subject: [PATCH] `torch.compile` compatibility with `generate` + static cache
 (#29114)

* fix compatibility

* working version

* cleanup

* sanity checks

* more sanity

* working version WITH refactor

* working without API change

* cleanup & tests pass

* more cleaning

* fix test

* fix tests

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* smaller comment

* update comment

* update comment

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/cache_utils.py      | 12 ++-
 src/transformers/generation/utils.py | 73 ++++++++++++-------
 .../models/llama/modeling_llama.py   | 42 ++++++-----
 3 files changed, 76 insertions(+), 51 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index abdc3c7c07..1cb7c429ae 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -357,7 +357,6 @@ class StaticCache(Cache):
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
         self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.seen_tokens = 0
 
     def update(
         self,
@@ -391,15 +390,20 @@ class StaticCache(Cache):
         k_out[:, :, new_cache_positions] = key_states
         v_out[:, :, new_cache_positions] = value_states
 
-        self.seen_tokens += key_states.shape[2]
         return k_out, v_out
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        return self.seen_tokens
+        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
+        raise ValueError(
+            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
+        )
 
     def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-        return self.seen_tokens
+        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
+        raise ValueError(
+            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
+        )
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
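The hunks above remove the Python-side `seen_tokens` counter from `StaticCache` and make `get_seq_length`/`get_usable_length` raise, because a stateful Python `int` on the cache object interferes with `torch.compile`. As a minimal, hypothetical sketch (shapes and names are illustrative, not the library implementation), a preallocated cache can instead be written to at explicit positions supplied by the caller:

import torch

# Shapes picked only for the example; the buffers keep a fixed shape throughout.
batch, num_heads, head_dim, max_cache_len = 1, 2, 4, 16
key_cache = torch.zeros(batch, num_heads, max_cache_len, head_dim)
value_cache = torch.zeros(batch, num_heads, max_cache_len, head_dim)


def update(key_states, value_states, cache_position):
    # Scatter the new states into the preallocated buffers at the given positions;
    # no counter is stored on the cache, so nothing stateful has to be traced.
    key_cache[:, :, cache_position] = key_states
    value_cache[:, :, cache_position] = value_states
    return key_cache, value_cache


# Prefill three tokens, then decode one token at position 3. The caller, not the
# cache, is responsible for tracking how many tokens have been written.
update(torch.randn(batch, num_heads, 3, head_dim), torch.randn(batch, num_heads, 3, head_dim), torch.arange(3))
update(torch.randn(batch, num_heads, 1, head_dim), torch.randn(batch, num_heads, 1, head_dim), torch.tensor([3]))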
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 08fde58507..d337e55934 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -648,6 +648,7 @@ class GenerationMixin:
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
+        model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -677,6 +678,8 @@ class GenerationMixin:
                 dim=-1,
             )
 
+        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+
         return model_kwargs
 
     def _reorder_cache(self, past_key_values, beam_idx):
@@ -1451,17 +1454,19 @@ class GenerationMixin:
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
 
-        # if we don't pass `past_key_values` and a cache_implementation is specified
-        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
-            "past_key_values", False
-        ):
-            cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
-            if not callable(getattr(self, "_setup_cache", None)):
-                raise ValueError(
-                    "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
-                    " Make sure it has a `_setup_cache` function."
-                )
-            self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
+        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if generation_config.cache_implementation == "static":
+                if model_kwargs.get("past_key_values", False) is not False:
+                    raise ValueError(
+                        "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
+                    )
+                cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
+                if not callable(getattr(self, "_setup_cache", None)):
+                    raise ValueError(
+                        "The `generation_config` defines a `cache_implementation` that is not compatible with this model."
+                        " Make sure it has a `_setup_cache` function."
+                    )
+                self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
@@ -1523,7 +1528,7 @@ class GenerationMixin:
             )
 
             # 12. run assisted generate
-            return self.assisted_decoding(
+            result = self.assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
@@ -1541,7 +1546,7 @@ class GenerationMixin:
             )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
             # 11. run greedy search
-            return self.greedy_search(
+            result = self.greedy_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
@@ -1559,7 +1564,7 @@ class GenerationMixin:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
 
-            return self.contrastive_search(
+            result = self.contrastive_search(
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
@@ -1589,7 +1594,7 @@ class GenerationMixin:
             )
 
             # 13. run sample
-            return self.sample(
+            result = self.sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
@@ -1623,7 +1628,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
            # 13. run beam search
-            return self.beam_search(
+            result = self.beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1662,7 +1667,7 @@ class GenerationMixin:
             )
 
             # 14. run beam sample
-            return self.beam_sample(
+            result = self.beam_sample(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1697,7 +1702,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            return self.group_beam_search(
+            result = self.group_beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1771,7 +1776,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            return self.constrained_beam_search(
+            result = self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1785,6 +1790,16 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
+        if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+            if not callable(getattr(self, "_reset_cache", None)):
+                raise ValueError(
+                    "A `static_cache` was used to generate but there was a failure when trying to release the cache. "
+                    " Make sure this model implements a `_reset_cache` function."
+                )
+            self._reset_cache()
+
+        return result
+
     @torch.no_grad()
     def contrastive_search(
         self,
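The dispatch refactor above, where every decoding strategy assigns to `result` instead of returning immediately, exists so that a static cache set up at the top of `generate` can be released before the single return point. Below is a schematic of that control flow only; apart from the `_setup_cache`/`_reset_cache` names taken from the patch, the class and helpers are placeholders, not the Transformers API.

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": "StaticCache"}  # value type simplified for the sketch


class ToyModel:
    def _setup_cache(self, max_batch_size, max_cache_len):
        print(f"allocate static cache: batch={max_batch_size}, len={max_cache_len}")

    def _reset_cache(self):
        print("release static cache")

    def generate(self, cache_implementation=None, past_key_values=None, batch_size=1, max_length=32):
        if cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
            if past_key_values is not None:
                raise ValueError("past_key_values cannot be combined with a static KV cache")
            self._setup_cache(max_batch_size=batch_size, max_cache_len=max_length)

        # Stands in for greedy/sample/beam search, which now assign instead of returning early.
        result = ["<decoded tokens>"]

        if cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
            self._reset_cache()  # the single exit point is what makes this cleanup possible
        return result


print(ToyModel().generate(cache_implementation="static"))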
@@ -1975,6 +1990,7 @@ class GenerationMixin:
                     model_kwargs,
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
+                    model_inputs=model_inputs,
                 )
                 if not sequential:
                     # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2169,7 +2185,7 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -2450,7 +2466,10 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                model_inputs=model_inputs,
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -2744,7 +2763,7 @@ class GenerationMixin:
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -3137,7 +3156,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3484,7 +3503,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -3883,7 +3902,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4235,7 +4254,7 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._temporary_reorder_cache(
@@ -4642,7 +4661,7 @@ class GenerationMixin:
             )
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
             )
 
             # if eos_token was found in one sentence, set sentence to finished
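Taken together with the `prepare_inputs_for_generation` changes in the Llama diff below, the `model_inputs`/`cache_position` plumbing added above means `generate` re-derives the write offset at every step from the positions used in the previous forward pass. A toy, standalone sketch of that bookkeeping (not the library code; lengths are illustrative):

import torch

prompt_len, max_new_tokens = 5, 3
cache_position = None  # nothing has been written before the first forward pass

for step in range(1 + max_new_tokens):
    # Mirrors prepare_inputs_for_generation: derive past_length from the last
    # positions written, then build the positions for the current step.
    past_length = 0 if cache_position is None else int(cache_position[-1]) + 1
    cur_len = prompt_len if step == 0 else 1  # prefill, then one token per decode step
    cache_position = torch.arange(past_length, past_length + cur_len)
    # Mirrors _update_model_kwargs_for_generation: the positions are carried over
    # to the next iteration via model_kwargs["cache_position"].
    print(step, cache_position.tolist())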
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5fb7e8459a..8e494adefc 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -641,6 +641,7 @@ class LlamaSdpaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
@@ -969,9 +970,11 @@ class LlamaModel(LlamaPreTrainedModel):
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
+                past_seen_tokens = past_key_values.get_seq_length()
 
         if cache_position is None:
+            if isinstance(past_key_values, StaticCache):
+                raise ValueError("cache_position is a required argument when using StaticCache.")
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
@@ -1043,6 +1046,10 @@ class LlamaModel(LlamaPreTrainedModel):
             attentions=all_self_attns,
         )
 
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -1058,16 +1065,8 @@ class LlamaModel(LlamaPreTrainedModel):
             causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
-        if hasattr(self, "causal_mask"):  # we use the current dtype to avoid any overflows
-            causal_mask = (
-                self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
-            )
-        else:
-            mask = torch.full(
-                (self.config.max_position_embeddings, self.config.max_position_embeddings),
-                fill_value=torch.finfo(dtype).min,
-            )
-            causal_mask = torch.triu(mask, diagonal=1)
+        # We use the current dtype to avoid any overflows
+        causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
 
         causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
@@ -1253,29 +1252,32 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]
 
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
 
         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
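As a usage sketch of what this patch targets, not code from the patch itself: generation with the preallocated "static" KV cache plus a compiled forward pass. The checkpoint name and generation settings are illustrative assumptions (the Llama 2 checkpoint is gated and requires access), and `device_map="auto"` additionally assumes `accelerate` is installed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)

# Ask generate() to set up a StaticCache sized to max_length (see _setup_cache above).
model.generation_config.cache_implementation = "static"

# Compile the forward pass; "reduce-overhead" uses CUDA graphs for the decode steps,
# which is where the fixed cache shape and contiguous inputs from this patch matter.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("The best thing about static KV caching is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))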