Fix on "cache position" for assisted generation (#30068)
* clean commit history I hope * get kv seq length correctly * PR suggestions * Update src/transformers/testing_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add comment * give gpt bigcode its own overridden method * remove code --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent
31921d8d5e
commit
77b59dce9f
|
@ -641,6 +641,7 @@ class GenerationMixin:
|
|||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
num_new_tokens: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values
|
||||
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||
|
@ -671,7 +672,7 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
|
||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
|
||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
||||
|
||||
return model_kwargs
|
||||
|
||||
|
@ -1294,6 +1295,21 @@ class GenerationMixin:
|
|||
|
||||
return generation_config, model_kwargs
|
||||
|
||||
def _get_initial_cache_position(self, input_ids, model_kwargs):
    """
    Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.

    Args:
        input_ids: Prompt token ids. Its last dimension is used as the current length when
            `model_kwargs` has no `inputs_embeds` entry, and its device is used for the result.
        model_kwargs: Generation keyword arguments. `past_key_values` and `inputs_embeds` are read
            when present; `cache_position` is written.

    Returns:
        The same `model_kwargs` dict, with `cache_position` set to
        `torch.arange(past_length, cur_len)`.
    """
    past_length = 0
    # An explicit `past_key_values=None` means there is no cache yet. A bare `in` test would let
    # `None` through and fail on the subscript below.
    past_key_values = model_kwargs.get("past_key_values")
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            past_length = past_key_values.get_seq_length()
        else:
            # NOTE(review): presumably the legacy nested-tuple layout, with the sequence length on
            # dim 2 of the first tensor of the first layer -- confirm against callers.
            past_length = past_key_values[0][0].shape[2]

    if "inputs_embeds" in model_kwargs:
        cur_len = model_kwargs["inputs_embeds"].shape[1]
    else:
        cur_len = input_ids.shape[-1]

    model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
    return model_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
|
@ -1560,6 +1576,8 @@ class GenerationMixin:
|
|||
raise ValueError("assisted generate is only supported for batch_size = 1")
|
||||
if not model_kwargs["use_cache"]:
|
||||
raise ValueError("assisted generate requires `use_cache=True`")
|
||||
if generation_config.cache_implementation == "static":
|
||||
raise ValueError("assisted generate is not supported with `static_cache`")
|
||||
|
||||
# 11. Get the candidate generator, given the parameterization
|
||||
candidate_generator = self._get_candidate_generator(
|
||||
|
@ -2024,11 +2042,9 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
batch_size = input_ids.shape[0]
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
this_peer_finished = False
|
||||
|
||||
|
@ -2495,12 +2511,10 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
batch_size = input_ids.shape[0]
|
||||
this_peer_finished = False
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
|
@ -2792,12 +2806,10 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
batch_size = input_ids.shape[0]
|
||||
this_peer_finished = False
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
# prepare model inputs
|
||||
|
@ -3108,9 +3120,7 @@ class GenerationMixin:
|
|||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
|
@ -3514,9 +3524,7 @@ class GenerationMixin:
|
|||
num_beams = beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
|
@ -3874,9 +3882,7 @@ class GenerationMixin:
|
|||
device = input_ids.device
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
if return_dict_in_generate and output_scores:
|
||||
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
|
||||
|
@ -4292,9 +4298,7 @@ class GenerationMixin:
|
|||
num_beams = constrained_beam_scorer.num_beams
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
|
@ -4655,11 +4659,9 @@ class GenerationMixin:
|
|||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
batch_size, cur_len = input_ids.shape
|
||||
if "inputs_embeds" in model_kwargs:
|
||||
cur_len = model_kwargs["inputs_embeds"].shape[1]
|
||||
batch_size = input_ids.shape[0]
|
||||
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
||||
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
|
||||
this_peer_finished = False
|
||||
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
|
||||
|
@ -4679,20 +4681,21 @@ class GenerationMixin:
|
|||
# we use this forward pass to also pick the subsequent logits in the original model.
|
||||
|
||||
# 2.1. Prepare the model inputs
|
||||
model_kwargs = _prepare_attention_mask(
|
||||
model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
|
||||
candidate_kwargs = copy.copy(model_kwargs)
|
||||
candidate_kwargs = _prepare_attention_mask(
|
||||
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
|
||||
)
|
||||
model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1])
|
||||
if "cache_position" in model_kwargs:
|
||||
model_kwargs["cache_position"] = torch.cat(
|
||||
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
|
||||
if "cache_position" in candidate_kwargs:
|
||||
candidate_kwargs["cache_position"] = torch.cat(
|
||||
(
|
||||
model_kwargs["cache_position"],
|
||||
candidate_kwargs["cache_position"],
|
||||
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs)
|
||||
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
|
||||
if "num_logits_to_keep" in model_inputs:
|
||||
model_inputs["num_logits_to_keep"] = candidate_length + 1
|
||||
|
||||
|
@ -4811,6 +4814,7 @@ class GenerationMixin:
|
|||
outputs,
|
||||
model_kwargs,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
num_new_tokens=n_matches + 1,
|
||||
)
|
||||
|
||||
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
||||
|
|
|
@ -1209,6 +1209,24 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
|
|||
)
|
||||
return model_inputs
|
||||
|
||||
def _get_initial_cache_position(self, input_ids, model_kwargs):
    """
    Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
    Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.

    Args:
        input_ids: Prompt token ids. Its last dimension is used as the current length when
            `model_kwargs` has no `inputs_embeds` entry, and its device is used for the result.
        model_kwargs: Generation keyword arguments. `past_key_values` and `inputs_embeds` are read
            when present; `cache_position` is written.

    Returns:
        The same `model_kwargs` dict, with `cache_position` set to
        `torch.arange(past_length, cur_len)`.
    """
    past_length = 0
    # An explicit `past_key_values=None` means there is no cache yet. A bare `in` test would let
    # `None` through and fail on the subscript below.
    past_key_values = model_kwargs.get("past_key_values")
    if past_key_values is not None:
        # The past length is read from the first cache entry: dim 1 when `config.multi_query`
        # is set, dim 2 otherwise.
        seq_dim = 1 if self.config.multi_query else 2
        past_length = past_key_values[0].shape[seq_dim]

    if "inputs_embeds" in model_kwargs:
        cur_len = model_kwargs["inputs_embeds"].shape[1]
    else:
        cur_len = input_ids.shape[-1]

    model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
    return model_kwargs
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
|
|
|
@ -231,6 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
|
|||
conv_kernel_size = config.mamba_d_conv
|
||||
self.conv_states = []
|
||||
self.ssm_states = []
|
||||
self.transformer_layers = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
if self.layers_block_type[i] == "mamba":
|
||||
self.conv_states += [
|
||||
|
@ -242,6 +243,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
|
|||
else:
|
||||
self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
|
||||
self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
|
||||
self.transformer_layers.append(i)
|
||||
|
||||
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
||||
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
||||
|
@ -276,6 +278,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
|
|||
device = self.ssm_states[layer_idx].device
|
||||
self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """
    Returns the sequence length of the cached states. A layer index can be optionally passed.

    Args:
        layer_idx: Layer to inspect. If it is not listed in `self.transformer_layers`, the first
            entry of that list is used instead.

    Returns:
        The size of the second-to-last dimension of the selected key cache, or 0 when there is no
        usable layer or its key cache holds no elements.
    """
    # take any layer that contains cache and not empty tensor
    if layer_idx not in self.transformer_layers:
        if not self.transformer_layers:
            # No layer to fall back to; report an empty cache instead of raising IndexError.
            return 0
        layer_idx = self.transformer_layers[0]
    if len(self.key_cache) <= layer_idx:
        return 0
    layer_keys = self.key_cache[layer_idx]
    # The constructor fills `key_cache` with `torch.tensor([[]] * batch_size)` placeholders of
    # shape (batch_size, 0); reading `shape[-2]` on those would return the batch size.
    if layer_keys.numel() == 0:
        return 0
    return layer_keys.shape[-2]
|
||||
|
||||
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
||||
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
||||
|
||||
|
|
|
@ -1091,8 +1091,9 @@ class GenerationTesterMixin:
|
|||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
||||
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||
|
@ -1151,7 +1152,13 @@ class GenerationTesterMixin:
|
|||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
|
||||
assistant_model = model
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
# case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
|
||||
if assistant_type == "random":
|
||||
assistant_model = model_class(config).to(torch_device).eval()
|
||||
else:
|
||||
assistant_model = model
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
|
|
Loading…
Reference in New Issue