Fix on "cache position" for assisted generation (#30068)
* clean commit history I hope
* get kv seq length correctly
* PR suggestions
* Update src/transformers/testing_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add comment
* give gpt bigcode its own overridden method
* remove code

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 31921d8d5e · commit 77b59dce9f
@@ -641,6 +641,7 @@ class GenerationMixin:
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
+        num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values
         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
@@ -671,7 +672,7 @@ class GenerationMixin:
         )

         if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
-            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

         return model_kwargs

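A minimal sketch (values assumed for illustration, not part of the diff) of what the new `num_new_tokens` argument changes: after a speculative step in which `n_matches` candidate tokens were accepted, `cache_position` has to jump past all of them plus the bonus token, instead of always advancing by one.

import torch

cache_position = torch.arange(0, 8)    # positions already filled (8-token prefix, assumed)
n_matches = 3                          # candidate tokens accepted this step (assumed)
num_new_tokens = n_matches + 1         # accepted candidates + the model's own new token

old_update = cache_position[-1:] + 1               # tensor([8])  -- pre-fix behaviour
new_update = cache_position[-1:] + num_new_tokens  # tensor([11]) -- post-fix behaviour
print(old_update, new_update)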
@@ -1294,6 +1295,21 @@ class GenerationMixin:

         return generation_config, model_kwargs

+    def _get_initial_cache_position(self, input_ids, model_kwargs):
+        """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+        past_length = 0
+        if "past_key_values" in model_kwargs:
+            if isinstance(model_kwargs["past_key_values"], Cache):
+                past_length = model_kwargs["past_key_values"].get_seq_length()
+            else:
+                past_length = model_kwargs["past_key_values"][0][0].shape[2]
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        else:
+            cur_len = input_ids.shape[-1]
+        model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+        return model_kwargs
+
     @torch.no_grad()
     def generate(
         self,
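A small sketch, with assumed lengths, of the positions the new helper computes for the pre-fill step: it starts counting from whatever is already in the cache instead of always starting at zero.

import torch

# No cache yet, 5-token prompt -> positions 0..4 are computed during pre-fill.
past_length, cur_len = 0, 5
print(torch.arange(past_length, cur_len))  # tensor([0, 1, 2, 3, 4])

# Resuming with a cache that already holds 8 tokens and a 10-token input
# (8 cached + 2 new) -> only positions 8 and 9 still need to be computed.
past_length, cur_len = 8, 10
print(torch.arange(past_length, cur_len))  # tensor([8, 9])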
@@ -1560,6 +1576,8 @@ class GenerationMixin:
                 raise ValueError("assisted generate is only supported for batch_size = 1")
             if not model_kwargs["use_cache"]:
                 raise ValueError("assisted generate requires `use_cache=True`")
+            if generation_config.cache_implementation == "static":
+                raise ValueError("assisted generate is not supported with `static_cache`")

             # 11. Get the candidate generator, given the parameterization
             candidate_generator = self._get_candidate_generator(
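For reference, a hedged sketch of the configuration combination the new guard rejects: a static cache pre-allocates fixed-size key/value tensors, which the assisted-generation path does not handle.

from transformers import GenerationConfig

generation_config = GenerationConfig(cache_implementation="static")
# The new guard in `generate()` refuses this combination when an assistant model is used:
if generation_config.cache_implementation == "static":
    print("assisted generate is not supported with `static_cache`")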
@@ -2024,11 +2042,9 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       batch_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
+       batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        this_peer_finished = False

@@ -2495,12 +2511,10 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       batch_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
+       batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
@@ -2792,12 +2806,10 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       batch_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
+       batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
@@ -3108,9 +3120,7 @@ class GenerationMixin:
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
@@ -3514,9 +3524,7 @@ class GenerationMixin:
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
@@ -3874,9 +3882,7 @@ class GenerationMixin:
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@@ -4292,9 +4298,7 @@ class GenerationMixin:
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
@@ -4655,11 +4659,9 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       batch_size, cur_len = input_ids.shape
-       if "inputs_embeds" in model_kwargs:
-           cur_len = model_kwargs["inputs_embeds"].shape[1]
+       batch_size = input_ids.shape[0]
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
+       model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        this_peer_finished = False
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
@@ -4679,20 +4681,21 @@ class GenerationMixin:
            # we use this forward pass to also pick the subsequent logits in the original model.

            # 2.1. Prepare the model inputs
-           model_kwargs = _prepare_attention_mask(
-               model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
+           candidate_kwargs = copy.copy(model_kwargs)
+           candidate_kwargs = _prepare_attention_mask(
+               candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
            )
-           model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1])
-           if "cache_position" in model_kwargs:
-               model_kwargs["cache_position"] = torch.cat(
+           candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+           if "cache_position" in candidate_kwargs:
+               candidate_kwargs["cache_position"] = torch.cat(
                    (
-                       model_kwargs["cache_position"],
+                       candidate_kwargs["cache_position"],
                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                    ),
                    dim=0,
                )

-           model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs)
+           model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
            if "num_logits_to_keep" in model_inputs:
                model_inputs["num_logits_to_keep"] = candidate_length + 1
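A hedged sketch (toy lengths, not from the diff) of why the candidate forward pass now works on a shallow copy: re-binding a key on the copy leaves the caller's `model_kwargs` untouched, so the real sequence's `cache_position` is not polluted with speculative positions.

import copy
import torch

model_kwargs = {"cache_position": torch.arange(8)}      # real sequence so far (assumed length)
candidate_kwargs = copy.copy(model_kwargs)               # shallow copy for the candidate pass
candidate_kwargs["cache_position"] = torch.cat(
    (candidate_kwargs["cache_position"], torch.arange(8, 13))  # 5 speculative positions (assumed)
)
print(model_kwargs["cache_position"].shape)      # torch.Size([8])  -- unchanged
print(candidate_kwargs["cache_position"].shape)  # torch.Size([13]) -- extended copy only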
@@ -4811,6 +4814,7 @@ class GenerationMixin:
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
+               num_new_tokens=n_matches + 1,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -1209,6 +1209,24 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
        )
        return model_inputs

+   def _get_initial_cache_position(self, input_ids, model_kwargs):
+       """
+       Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.
+       Since gpt bigcode is special, the method is overridden here, other models use it from `generation.utils.py`.
+       """
+       past_length = 0
+       if "past_key_values" in model_kwargs:
+           if self.config.multi_query:
+               past_length = model_kwargs["past_key_values"][0].shape[1]
+           else:
+               past_length = model_kwargs["past_key_values"][0].shape[2]
+       if "inputs_embeds" in model_kwargs:
+           cur_len = model_kwargs["inputs_embeds"].shape[1]
+       else:
+           cur_len = input_ids.shape[-1]
+       model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
+       return model_kwargs
+
    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
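A hedged toy illustration (tensor shapes assumed for the two layouts the override reads) of why GPT-BigCode cannot reuse the generic helper: it caches a single fused key/value tensor per layer, and in the multi-query case the sequence dimension sits at index 1 rather than index 2.

import torch

batch, past_len, head_dim, num_heads = 2, 8, 64, 12
mqa_layer_cache = torch.zeros(batch, past_len, 2 * head_dim)             # multi_query=True (assumed layout)
mha_layer_cache = torch.zeros(batch, num_heads, past_len, 2 * head_dim)  # multi_query=False (assumed layout)

print(mqa_layer_cache.shape[1])  # 8 -> past length read from dim 1 for the multi-query layout
print(mha_layer_cache.shape[2])  # 8 -> past length read from dim 2 otherwise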
@@ -231,6 +231,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
        conv_kernel_size = config.mamba_d_conv
        self.conv_states = []
        self.ssm_states = []
+       self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                self.conv_states += [
@@ -242,6 +243,7 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
            else:
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
+               self.transformer_layers.append(i)

        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
@@ -276,6 +278,14 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

+   def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+       """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+       # take any layer that contains cache and not empty tensor
+       layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+       if len(self.key_cache) <= layer_idx:
+           return 0
+       return self.key_cache[layer_idx].shape[-2]
+
    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
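A hedged sketch (made-up shapes) of why `get_seq_length` redirects to one of the tracked transformer layers: mamba layers hold empty key/value tensors, so reading their shape would always report a past length of 0.

import torch

key_cache = [torch.empty(2, 0), torch.zeros(2, 4, 6, 8)]  # layer 0: mamba (empty), layer 1: attention
transformer_layers = [1]

layer_idx = 0
layer_idx = transformer_layers[0] if layer_idx not in transformer_layers else layer_idx
print(key_cache[layer_idx].shape[-2])  # 6 -> cached sequence length taken from the attention layer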
|
@ -1091,8 +1091,9 @@ class GenerationTesterMixin:
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
|
|
||||||
|
@parameterized.expand([("random",), ("same",)])
|
||||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||||
def test_assisted_decoding_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||||
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
# NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul
|
||||||
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
# shape differences -- and it may result in a different output. The input shape difference happens in the
|
||||||
|
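A self-contained sketch of the `parameterized.expand` pattern the test now uses: the decorator generates one test case per tuple and passes the value in as `assistant_type`.

import unittest
from parameterized import parameterized

class ExampleTest(unittest.TestCase):
    @parameterized.expand([("random",), ("same",)])
    def test_assistant_type(self, assistant_type):
        # Two test cases are generated, one per value.
        self.assertIn(assistant_type, {"random", "same"})

if __name__ == "__main__":
    unittest.main()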
@@ -1151,7 +1152,13 @@ class GenerationTesterMixin:
            }
            output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)

-           assistant_model = model
+           # test with the same assistant model or randomly init one
+           # in the first case all candidate tokens are accepted, in the second none is accepted
+           # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :)
+           if assistant_type == "random":
+               assistant_model = model_class(config).to(torch_device).eval()
+           else:
+               assistant_model = model
            assistant_model.generation_config.num_assistant_tokens = 2  # see b)
            assistant_model.generation_config.num_assistant_tokens_schedule = "constant"  # see b)
            generation_kwargs.update({"assistant_model": assistant_model})
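To close the loop, a hedged end-to-end usage sketch of the feature the test exercises (the checkpoints below are placeholders, not taken from the PR): assisted generation pairs the main model with a smaller assistant whose candidate tokens are verified in a single forward pass, which is exactly where the corrected `cache_position` bookkeeping matters.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")                 # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")              # main model
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")    # placeholder assistant model

inputs = tokenizer("The cache position fix matters because", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))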