Fix decoding score comparison when using logits processors or warpers (#10638)
* Normalize using a logits warper * Add a flag in `generate` to support the logit renormalization * Add in RAG
This commit is contained in:
parent
eb5bdcdfa5
commit
f7196f2e63
|
@ -679,3 +679,16 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||
self.regulation_factor, cur_len - self.regulation_start
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
class LogitNormalization(LogitsProcessor, LogitsWarper):
    r"""
    [`LogitsWarper`] and [`LogitsProcessor`] that renormalizes the scores with a log-softmax over the last dimension.

    The beam search implemented in this library normalizes the scores only *before* the logits processors and warpers
    are applied, yet it still assumes the scores are normalized when it compares hypotheses. Processors and warpers
    may break that normalization, so this one should be applied after all of them.
    """

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # `input_ids` is not used here; only `scores` is transformed.
        return torch.log_softmax(scores, dim=-1)
|
||||
|
|
|
@ -32,6 +32,7 @@ from .generation_logits_process import (
|
|||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
|
@ -636,6 +637,7 @@ class GenerationMixin:
|
|||
typical_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
num_beams: Optional[int] = None,
|
||||
renormalize_logits: Optional[bool] = None,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
||||
|
@ -660,6 +662,9 @@ class GenerationMixin:
|
|||
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||
if typical_p is not None and typical_p < 1.0:
|
||||
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if renormalize_logits is True:
|
||||
warpers.append(LogitNormalization())
|
||||
return warpers
|
||||
|
||||
def _get_logits_processor(
|
||||
|
@ -682,6 +687,7 @@ class GenerationMixin:
|
|||
remove_invalid_values: bool,
|
||||
exponential_decay_length_penalty: Tuple,
|
||||
logits_processor: Optional[LogitsProcessorList],
|
||||
renormalize_logits: Optional[bool],
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
||||
|
@ -754,6 +760,9 @@ class GenerationMixin:
|
|||
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
|
||||
)
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if renormalize_logits is True:
|
||||
processors.append(LogitNormalization())
|
||||
return processors
|
||||
|
||||
def _get_stopping_criteria(
|
||||
|
@ -858,6 +867,7 @@ class GenerationMixin:
|
|||
diversity_penalty: Optional[float] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
||||
renormalize_logits: Optional[bool] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
||||
constraints: Optional[List[Constraint]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
|
@ -986,6 +996,10 @@ class GenerationMixin:
|
|||
Custom logits processors that complement the default logits processors built from arguments and a
|
||||
model's config. If a logit processor is passed that is already created with the arguments or a model's
|
||||
config an error is thrown. This feature is intended for advanced users.
|
||||
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
||||
Whether to renormalize the logits after applying all the logits processors or warpers (including the
|
||||
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
||||
score logits are normalized but some logit processors or warpers break the normalization.
|
||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
||||
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
||||
model's config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
|
@ -1241,6 +1255,7 @@ class GenerationMixin:
|
|||
remove_invalid_values=remove_invalid_values,
|
||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||
logits_processor=logits_processor,
|
||||
renormalize_logits=renormalize_logits,
|
||||
)
|
||||
|
||||
# 8. prepare stopping criteria
|
||||
|
@ -1271,7 +1286,12 @@ class GenerationMixin:
|
|||
elif is_sample_gen_mode:
|
||||
# 10. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(
|
||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
typical_p=typical_p,
|
||||
temperature=temperature,
|
||||
num_beams=num_beams,
|
||||
renormalize_logits=renormalize_logits,
|
||||
)
|
||||
|
||||
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
|
@ -1333,7 +1353,12 @@ class GenerationMixin:
|
|||
elif is_beam_sample_gen_mode:
|
||||
# 10. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(
|
||||
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
typical_p=typical_p,
|
||||
temperature=temperature,
|
||||
num_beams=num_beams,
|
||||
renormalize_logits=renormalize_logits,
|
||||
)
|
||||
|
||||
if stopping_criteria.max_length is None:
|
||||
|
|
|
@ -1400,6 +1400,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
n_docs: Optional[int] = None,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
|
||||
renormalize_logits: Optional[bool] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
|
@ -1624,6 +1625,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
remove_invalid_values=remove_invalid_values,
|
||||
exponential_decay_length_penalty=exponential_decay_length_penalty,
|
||||
logits_processor=logits_processor,
|
||||
renormalize_logits=renormalize_logits,
|
||||
)
|
||||
|
||||
if num_beams == 1:
|
||||
|
|
|
@ -33,6 +33,7 @@ if is_torch_available():
|
|||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
|
@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
|
||||
).all()
|
||||
)
|
||||
|
||||
def test_normalization(self):
    """Check that `LogitNormalization` maps raw scores to log-probabilities."""
    raw_scores = torch.tensor(
        [[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]],
        device=torch_device,
        dtype=torch.float,
    )

    # `input_ids` is passed as `None`; exponentiate to get back to probability space.
    normalizer = LogitNormalization()
    probabilities = normalizer(None, raw_scores).exp()

    # Every row must sum to one once normalized ...
    expected_row_sums = torch.ones(raw_scores.shape[0], device=torch_device, dtype=torch.float)
    self.assertTrue(probabilities.sum(dim=-1).allclose(expected_row_sums))

    # ... and must match a plain softmax over the same raw scores.
    self.assertTrue(probabilities.allclose(raw_scores.softmax(dim=-1)))
|
||||
|
|
Loading…
Reference in New Issue