@@ -75,6 +75,7 @@ from .logits_process import (
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
 )
 from .stopping_criteria import (
+    EosTokenCriteria,
     MaxLengthCriteria,
     MaxTimeCriteria,
     StoppingCriteria,
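
For context, the new `EosTokenCriteria` import moves EOS handling into the stopping-criteria framework. A minimal sketch of the shape such a criterion takes, assuming the `StoppingCriteria.__call__(input_ids, scores)` protocol used in this file (the canonical implementation lives in `stopping_criteria.py`):

import torch
from transformers import StoppingCriteria

class EosCriterionSketch(StoppingCriteria):
    def __init__(self, eos_token_id):
        # Normalize to a tensor so multiple EOS ids are supported
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id = torch.tensor(eos_token_id)

    def __call__(self, input_ids, scores, **kwargs):
        # One boolean per sequence: True once the last generated token is an EOS id
        return torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
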
@@ -690,6 +691,7 @@ class GenerationMixin:
             candidate_generator = PromptLookupCandidateGenerator(
                 num_output_tokens=generation_config.prompt_lookup_num_tokens,
                 max_matching_ngram_size=generation_config.max_matching_ngram_size,
+                max_length=generation_config.max_length,
             )
         else:
             candidate_generator = AssistedCandidateGenerator(
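
Prompt lookup decoding is driven entirely by `generate()` kwargs; with the hunk above, the candidates the generator drafts are also capped by `max_length`. A minimal usage sketch (the checkpoint name is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over", return_tensors="pt")
# Setting `prompt_lookup_num_tokens` routes generation through
# PromptLookupCandidateGenerator; drafted n-grams now respect `max_length`.
output = model.generate(**inputs, prompt_lookup_num_tokens=10, max_length=48)
print(tokenizer.decode(output[0], skip_special_tokens=True))
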
@@ -892,6 +894,8 @@ class GenerationMixin:
             )
         if generation_config.max_time is not None:
             criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
+        if generation_config.eos_token_id is not None:
+            criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
         criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
         return criteria
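
With this change, `_get_stopping_criteria` folds the EOS check into the same list as the length and time limits. Roughly, the assembled list is equivalent to building one by hand (a sketch; the real method also merges in user-supplied criteria, and the eos id value is model-specific):

from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    MaxLengthCriteria,
    MaxTimeCriteria,
)

criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=128),
        MaxTimeCriteria(max_time=5.0),     # seconds
        EosTokenCriteria(eos_token_id=2),  # illustrative id
    ]
)
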
@@ -1306,7 +1310,7 @@ class GenerationMixin:

         Return:
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
-            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
+            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.

             If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
             [`~utils.ModelOutput`] types are:
@@ -1515,7 +1519,6 @@ class GenerationMixin:
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1530,7 +1533,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1550,7 +1552,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1579,7 +1580,6 @@ class GenerationMixin:
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1613,7 +1613,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1653,7 +1652,6 @@ class GenerationMixin:
                 logits_warper=logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1687,7 +1685,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1761,7 +1758,6 @@ class GenerationMixin:
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
                 output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
@@ -1916,11 +1912,28 @@ class GenerationMixin:
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        sequential = sequential if sequential is not None else self.generation_config.low_memory
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2186,12 +2199,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
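
This deprecation path repeats across all decoding methods below: an explicit `eos_token_id` argument is converted into an appended `EosTokenCriteria`, and the per-loop EOS tensor bookkeeping (deleted here) becomes redundant. Per the warning text, callers migrating off the deprecated argument would do something like the following (the eos id value is illustrative):

from transformers import StoppingCriteriaList
from transformers.generation.stopping_criteria import EosTokenCriteria

stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=2)])
# ...then pass `stopping_criteria=stopping_criteria` instead of `eos_token_id=2`,
# or simply set `model.generation_config.eos_token_id = 2`.
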
@@ -2365,10 +2372,27 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
             output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -2463,12 +2487,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
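
The deleted masking trick is worth unpacking once, since it disappears from several loops: it compared each freshly generated token against every EOS id and zeroed out finished rows. A toy reproduction with two sequences and EOS ids {2, 3}:

import torch

eos_token_id_tensor = torch.tensor([2, 3])
next_tokens = torch.tensor([2, 7])  # sequence 0 just emitted an EOS
unfinished_sequences = torch.ones(2, dtype=torch.long)

# ne() gives one row per EOS id; prod(dim=0) is 0 iff any id matched
keep = next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
unfinished_sequences = unfinished_sequences.mul(keep)
print(unfinished_sequences)  # tensor([0, 1])
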
@@ -2650,10 +2668,27 @@ class GenerationMixin:
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -2751,12 +2786,6 @@ class GenerationMixin:
                 is_encoder_decoder=self.config.is_encoder_decoder,
             )

-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -2966,7 +2995,25 @@ class GenerationMixin:
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3351,7 +3398,25 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -3688,7 +3753,25 @@ class GenerationMixin:
             )
         stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4089,7 +4172,25 @@ class GenerationMixin:
         if len(stopping_criteria) == 0:
             warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
@@ -4421,12 +4522,27 @@ class GenerationMixin:
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         if eos_token_id is not None and pad_token_id is None:
             raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+        if eos_token_id is not None:
+            logger.warning_once(
+                "`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
+                " `stopping_criteria=StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)])` instead."
+                " Otherwise make sure to set `model.generation_config.eos_token_id`",
+                FutureWarning,
+            )
+            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # TODO remove when the method is totally private and beam scorer refactored
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
         output_attentions = (
@@ -4462,9 +4578,6 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

-        # other auxiliary variables
-        max_len = stopping_criteria[0].max_length
-
         this_peer_finished = False
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
@@ -4476,13 +4589,7 @@ class GenerationMixin:
                 candidate_logits = candidate_logits.to(self.device)

             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
-            last_assistant_token_is_eos = (
-                ~candidate_input_ids[:, -1]
-                .tile(eos_token_id_tensor.shape[0], 1)
-                .ne(eos_token_id_tensor.unsqueeze(1))
-                .prod(dim=0)
-                .bool()
-            )
+            is_done_candidate = stopping_criteria(candidate_input_ids, None)

             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
             # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
@@ -4525,15 +4632,13 @@ class GenerationMixin:
             # 3. Select the accepted tokens. There are two possible cases:
             # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
             # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
-            max_matches = max_len - cur_len - 1
             if do_sample and candidate_logits is not None:
                 valid_tokens, n_matches = _speculative_sampling(
                     candidate_input_ids,
                     candidate_logits,
                     candidate_length,
                     new_logits,
-                    last_assistant_token_is_eos,
-                    max_matches,
+                    is_done_candidate,
                 )

             # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
@@ -4550,9 +4655,8 @@ class GenerationMixin:
                 n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

                 # Ensure we don't generate beyond max_len or an EOS token
-                if last_assistant_token_is_eos and n_matches == candidate_length:
+                if is_done_candidate and n_matches == candidate_length:
                     n_matches -= 1
-                n_matches = min(n_matches, max_matches)
                 valid_tokens = selected_tokens[:, : n_matches + 1]

                 # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
|
|
|
|
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# if eos_token was found in one sentence, set sentence to finished
|
|
|
|
|
if eos_token_id_tensor is not None:
|
|
|
|
|
unfinished_sequences = unfinished_sequences.mul(
|
|
|
|
|
input_ids[:, -1]
|
|
|
|
|
.tile(eos_token_id_tensor.shape[0], 1)
|
|
|
|
|
.ne(eos_token_id_tensor.unsqueeze(1))
|
|
|
|
|
.prod(dim=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
|
|
|
|
|
this_peer_finished = unfinished_sequences.max() == 0
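
After these removals, sequence completion in every loop hinges on one line: `stopping_criteria(input_ids, scores)` returns a per-sequence done flag that is folded into `unfinished_sequences`. In isolation, with toy values:

import torch

unfinished_sequences = torch.ones(3, dtype=torch.long)
is_done = torch.tensor([True, False, False])  # stand-in for stopping_criteria(input_ids, scores)

unfinished_sequences = unfinished_sequences & ~is_done
this_peer_finished = unfinished_sequences.max() == 0
print(unfinished_sequences, this_peer_finished)  # tensor([0, 1, 1]) tensor(False)
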
@@ -4678,8 +4773,7 @@ def _speculative_sampling(
     candidate_logits,
     candidate_length,
     new_logits,
-    last_assistant_token_is_eos,
-    max_matches,
+    is_done_candidate,
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
@@ -4704,16 +4798,14 @@ def _speculative_sampling(
     n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

     # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
-    if last_assistant_token_is_eos and n_matches == candidate_length:
+    if is_done_candidate and n_matches == candidate_length:
         # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
         # due to acceptance on EOS we fix `n_matches`
         n_matches -= 1
         valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
     else:
-        n_matches = min(n_matches, max_matches)
-
         # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
-        gamma = min(candidate_logits.shape[1], max_matches)
+        gamma = candidate_logits.shape[1]
         p_n_plus_1 = p[:, n_matches, :]
         if n_matches < gamma:
             q_n_plus_1 = q[:, n_matches, :]
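
The rule this function implements (algorithm 1 of the cited paper, https://arxiv.org/pdf/2211.17192.pdf) reduces to: accept a drafted token x with probability min(1, p(x)/q(x)), where q is the assistant's distribution and p the target model's; on rejection, resample from the renormalized residual max(0, p - q). A numeric sketch with toy distributions:

import torch

p = torch.tensor([0.1, 0.6, 0.3])  # target model distribution (toy)
q = torch.tensor([0.3, 0.3, 0.4])  # assistant distribution (toy)
x = 0                              # token drafted by the assistant

accept_prob = torch.clamp(p[x] / q[x], max=1.0)  # min(1, p/q) = 1/3 here
if torch.rand(()) < accept_prob:
    next_token = x
else:
    residual = torch.clamp(p - q, min=0.0)       # [0.0, 0.3, 0.0]
    next_token = torch.multinomial(residual / residual.sum(), 1).item()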