Fix on "cache position" for assisted generation (#30068)

* clean commit history I hope

* get kv seq length correctly

* PR suggestions

* Update src/transformers/testing_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add comment

* give gpt bigcode its own overridden method

* remove code

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Raushan Turganbay 2024-04-23 16:23:36 +05:00 committed by GitHub
parent 31921d8d5e
commit 77b59dce9f
4 changed files with 77 additions and 38 deletions


@@ -641,6 +641,7 @@ class GenerationMixin:
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
+        num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -671,7 +672,7 @@
                 )

         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
-            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

         return model_kwargs
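Why this matters: with assisted generation a single decoding step can commit several tokens at once (the accepted candidate tokens plus one bonus token), so bumping `cache_position` by a hard-coded `+1` leaves it pointing at positions that are already cached. A minimal arithmetic sketch with made-up numbers (illustrative only, not library code):

    import torch

    cache_position = torch.arange(5)        # pre-fill over a 5-token prompt: [0, 1, 2, 3, 4]
    n_matches = 2                           # candidate tokens accepted in this step
    num_new_tokens = n_matches + 1          # plus the "bonus" token sampled from the verifier logits

    cache_position[-1:] + 1                 # tensor([5]) -- stale: 3 tokens were committed, not 1
    cache_position[-1:] + num_new_tokens    # tensor([7]) -- the next position to be written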
@@ -1294,6 +1295,21 @@

         return generation_config, model_kwargs

+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        past_length = 0
+        if "past_key_values" in model_kwargs:
+            if isinstance(model_kwargs["past_key_values"], Cache):
+                past_length = model_kwargs["past_key_values"].get_seq_length()
+            else:
+                past_length = model_kwargs["past_key_values"][0][0].shape[2]
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        else:
+            cur_len = input_ids.shape[-1]
+        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+        return model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
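The new `_get_initial_cache_position` helper above centralizes the pre-fill computation and, unlike the per-method `torch.arange(cur_len)` it replaces below, accounts for an already-populated cache. A rough sketch of the values it produces (toy numbers, not library code):

    import torch

    # Fresh 6-token prompt, no past cache: every position is pre-filled.
    torch.arange(0, 6)                   # tensor([0, 1, 2, 3, 4, 5])

    # Resuming with a cache that already holds 6 tokens and input_ids of total length 8:
    past_length, cur_len = 6, 8
    torch.arange(past_length, cur_len)   # tensor([6, 7]) -- only the uncached positions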
@@ -1560,6 +1576,8 @@
                 raise ValueError("assisted generate is only supported for batch_size = 1")
             if not model_kwargs["use_cache"]:
                 raise ValueError("assisted generate requires `use_cache=True`")
+            if generation_config.cache_implementation == "static":
+                raise ValueError("assisted generate is not supported with `static_cache`")

             # 11. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
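The new guard above makes the unsupported combination fail fast instead of silently desynchronizing the cache. Roughly, assuming the usual `generate` kwargs (hypothetical snippet, not from the test suite):

    from transformers import GenerationConfig

    # Requesting a static cache together with an assistant model now raises inside `generate`:
    generation_config = GenerationConfig(cache_implementation="static")
    # model.generate(**inputs, generation_config=generation_config, assistant_model=assistant_model)
    # -> ValueError: assisted generate is not supported with `static_cache`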
@@ -2024,11 +2042,9 @@
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size = input_ids.shape[0]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         this_peer_finished = False
@@ -2495,12 +2511,10 @@
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size = input_ids.shape[0]
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
@@ -2792,12 +2806,10 @@
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size = input_ids.shape[0]
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
@@ -3108,9 +3120,7 @@
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -3514,9 +3524,7 @@
         num_beams = beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -3874,9 +3882,7 @@
         device = input_ids.device

         batch_beam_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         if return_dict_in_generate and output_scores:
             beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -4292,9 +4298,7 @@
         num_beams = constrained_beam_scorer.num_beams

         batch_beam_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         if num_beams * batch_size != batch_beam_size:
             raise ValueError(
@@ -4655,11 +4659,9 @@
         )

         # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        if "inputs_embeds" in model_kwargs:
-            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        batch_size = input_ids.shape[0]
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

         this_peer_finished = False

         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
@@ -4679,20 +4681,21 @@
            # we use this forward pass to also pick the subsequent logits in the original model.

            # 2.1. Prepare the model inputs
-            model_kwargs = _prepare_attention_mask(
-                model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
+            candidate_kwargs = copy.copy(model_kwargs)
+            candidate_kwargs = _prepare_attention_mask(
+                candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
             )
-            model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1])
-            if "cache_position" in model_kwargs:
-                model_kwargs["cache_position"] = torch.cat(
+            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+            if "cache_position" in candidate_kwargs:
+                candidate_kwargs["cache_position"] = torch.cat(
                     (
-                        model_kwargs["cache_position"],
+                        candidate_kwargs["cache_position"],
                         torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                     ),
                     dim=0,
                 )

-            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
             if "num_logits_to_keep" in model_inputs:
                 model_inputs["num_logits_to_keep"] = candidate_length + 1
@@ -4811,6 +4814,7 @@
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
             )

             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
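Taken together, the assisted-decoding changes keep the caller's `model_kwargs` authoritative: the speculative forward pass works on a shallow copy whose `cache_position` is temporarily extended by `candidate_length`, and only after verification does the real `cache_position` advance, by the number of tokens actually committed. A simplified sketch of that flow (hypothetical helper, not the library implementation):

    import copy
    import torch

    def assisted_step(model_kwargs, candidate_length, n_matches):
        # Shallow copy: the speculative pass must not mutate the real kwargs.
        candidate_kwargs = copy.copy(model_kwargs)
        last = int(candidate_kwargs["cache_position"][-1])
        candidate_kwargs["cache_position"] = torch.cat(
            (candidate_kwargs["cache_position"], torch.arange(last + 1, last + 1 + candidate_length))
        )
        # ... the verification forward pass runs with candidate_kwargs ...

        # Advance the real cache_position only by the tokens that were kept (accepted + bonus).
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + (n_matches + 1)
        return model_kwargs

    kwargs = {"cache_position": torch.arange(5)}              # 5-token prompt
    kwargs = assisted_step(kwargs, candidate_length=3, n_matches=2)
    kwargs["cache_position"]                                  # tensor([7])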


@@ -1209,6 +1209,24 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
         )
         return model_inputs

+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """
+        Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
+        Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
+        """
+        past_length = 0
+        if "past_key_values" in model_kwargs:
+            if self.config.multi_query:
+                past_length = model_kwargs["past_key_values"][0].shape[1]
+            else:
+                past_length = model_kwargs["past_key_values"][0].shape[2]
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        else:
+            cur_len = input_ids.shape[-1]
+        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+        return model_kwargs
+
     @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
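GPT BigCode needs the override above because its per-layer cache is a single fused key/value tensor rather than a `(key, value)` pair, so the generic `past_key_values[0][0].shape[2]` lookup would read the wrong dimension. A toy shape comparison (made-up sizes, shapes as I understand the model; not library code):

    import torch

    batch, seq_len, num_heads, head_dim = 2, 7, 4, 8

    # Typical legacy cache layer elsewhere: (key, value), each (batch, num_heads, seq_len, head_dim)
    generic_layer = (
        torch.zeros(batch, num_heads, seq_len, head_dim),
        torch.zeros(batch, num_heads, seq_len, head_dim),
    )
    generic_layer[0].shape[2]      # 7 -- what the generic helper reads

    # GPT BigCode stores one fused key/value tensor per layer:
    mqa_layer = torch.zeros(batch, seq_len, 2 * head_dim)             # multi_query=True  -> seq dim 1
    mha_layer = torch.zeros(batch, num_heads, seq_len, 2 * head_dim)  # multi_query=False -> seq dim 2
    mqa_layer.shape[1], mha_layer.shape[2]                            # (7, 7)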


@@ -231,6 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
         conv_kernel_size = config.mamba_d_conv
         self.conv_states = []
         self.ssm_states = []
+        self.transformer_layers = []
         for i in range(config.num_hidden_layers):
             if self.layers_block_type[i] == "mamba":
                 self.conv_states += [
@@ -242,6 +243,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             else:
                 self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                 self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
+                self.transformer_layers.append(i)

         self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
         self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
@@ -276,6 +278,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
             device = self.ssm_states[layer_idx].device
             self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # take any layer that contains cache and not empty tensor
+        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
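In the hybrid cache above, mamba layers only hold empty key/value placeholders, so `get_seq_length` must read the length from an attention layer; `self.transformer_layers` records which indices those are. A toy standalone illustration of the lookup (made-up shapes, not the library class):

    import torch

    # Layer 0: mamba layer, empty placeholder. Layer 1: attention layer with 9 cached positions.
    key_cache = [torch.tensor([[]]), torch.zeros(1, 4, 9, 16)]   # (batch, heads, seq_len, head_dim)
    transformer_layers = [1]

    def get_seq_length(layer_idx=0):
        # Redirect to a layer that actually holds key/value states.
        layer_idx = transformer_layers[0] if layer_idx not in transformer_layers else layer_idx
        if len(key_cache) <= layer_idx:
            return 0
        return key_cache[layer_idx].shape[-2]

    get_seq_length()    # 9, even though layer 0 was asked for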


@@ -1091,8 +1091,9 @@ class GenerationTesterMixin:
             )
             self.assertListEqual(low_output.tolist(), high_output.tolist())

+    @parameterized.expand([("random",), ("same",)])
     @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
-    def test_assisted_decoding_matches_greedy_search(self):
+    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
         # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
         # shape differences -- and it may result in a different output. The input shape difference happens in the
@@ -1151,7 +1152,13 @@
             }
             output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

-            assistant_model = model
+            # test with the same assistant model or randomly init one
+            # in the first case all candidate tokens are accepted, in the second none is accepted
+            # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
+            if assistant_type == "random":
+                assistant_model = model_class(config).to(torch_device).eval()
+            else:
+                assistant_model = model
             assistant_model.generation_config.num_assistant_tokens = 2  # see b)
             assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
             generation_kwargs.update({"assistant_model": assistant_model})
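The `parameterized.expand` decorator above splits the test into two named cases: with a randomly initialized assistant essentially no candidate tokens are accepted, and with the model as its own assistant all of them are, which exercises both extremes of the `num_new_tokens` bookkeeping. A small standalone illustration of the mechanism (toy test, not part of the transformers suite):

    import unittest
    from parameterized import parameterized

    class ToyTest(unittest.TestCase):
        @parameterized.expand([("random",), ("same",)])
        def test_assistant(self, assistant_type):
            # Expands into test_assistant_0_random and test_assistant_1_same.
            self.assertIn(assistant_type, {"random", "same"})

    if __name__ == "__main__":
        unittest.main()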