Fix decoding score comparison when using logits processors or warpers (#10638)

* Normalize using a logits warper

* Add a flag in `generate` to support the logit renormalization

* Add in RAG
Santiago Castro 2022-04-13 04:37:33 -04:00 committed by GitHub
parent eb5bdcdfa5
commit f7196f2e63
4 changed files with 58 additions and 2 deletions

src/transformers/generation_logits_process.py

@@ -679,3 +679,16 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
self.regulation_factor, cur_len - self.regulation_start
)
return scores
class LogitNormalization(LogitsProcessor, LogitsWarper):
r"""
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It is important to
renormalize the scores during beam search, after applying the logits processors or warpers, because the search
algorithm in this library assumes the scores are normalized when comparing hypotheses, yet it only normalizes
them before the processors and warpers run, and those may break the normalization.
"""
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores.log_softmax(dim=-1)
return scores
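As a quick illustration of the new class (a minimal sketch, not part of the commit; the import path follows the module layout shown in this diff):

import torch
from transformers.generation_logits_process import LogitNormalization

# raw, unnormalized scores for a batch of two hypotheses
scores = torch.tensor([[2.0, 1.0, 0.5], [0.0, -1.0, 3.0]])
normalized = LogitNormalization()(input_ids=None, scores=scores)
# log-softmax output exponentiates to proper probabilities: each row sums to 1
assert torch.allclose(normalized.exp().sum(dim=-1), torch.ones(2))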

src/transformers/generation_utils.py

@@ -32,6 +32,7 @@ from .generation_logits_process import (
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -636,6 +637,7 @@ class GenerationMixin:
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
renormalize_logits: Optional[bool] = None,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -660,6 +662,9 @@
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if typical_p is not None and typical_p < 1.0:
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def _get_logits_processor(
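To see why `LogitNormalization` must come last among the warpers, consider a minimal sketch (not part of the commit; the values are arbitrary) where a top-k warper masks most tokens to -inf and thereby breaks the normalization:

import torch
from transformers import LogitsProcessorList, TopKLogitsWarper
from transformers.generation_logits_process import LogitNormalization

# start from properly normalized log-probabilities over a 10-token vocabulary
scores = torch.log_softmax(torch.randn(1, 10), dim=-1)

masked = LogitsProcessorList([TopKLogitsWarper(top_k=2)])(None, scores)
print(masked.exp().sum(dim=-1))  # < 1: the top-k mask broke the normalization

renormed = LogitsProcessorList([TopKLogitsWarper(top_k=2), LogitNormalization()])(None, scores)
print(renormed.exp().sum(dim=-1))  # 1.0 again: scores are safe to compare across hypotheses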
@@ -682,6 +687,7 @@ class GenerationMixin:
remove_invalid_values: bool,
exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
renormalize_logits: Optional[bool],
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
@@ -754,6 +760,9 @@
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
processors.append(LogitNormalization())
return processors
def _get_stopping_criteria(
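The same ordering applies to the processors. A minimal sketch (not part of the commit; the token ids and vocabulary size are made up) using `MinLengthLogitsProcessor`, which sets the EOS score to -inf while the sequence is still short:

import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.generation_logits_process import LogitNormalization

eos_token_id = 0  # hypothetical id for this sketch
processors = LogitsProcessorList(
    [MinLengthLogitsProcessor(min_length=5, eos_token_id=eos_token_id), LogitNormalization()]
)
input_ids = torch.tensor([[3, 7]])  # current length 2 < min_length, so EOS gets masked
scores = torch.log_softmax(torch.randn(1, 10), dim=-1)
out = processors(input_ids, scores)
# EOS stays at -inf, yet the remaining scores renormalize to a proper distribution
assert out[0, eos_token_id] == float("-inf")
assert torch.isclose(out.exp().sum(), torch.tensor(1.0))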
@@ -858,6 +867,7 @@
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
@@ -986,6 +996,10 @@
Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. This feature is intended for advanced users.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
custom ones). It is highly recommended to set this flag to `True`, as the search algorithms assume the
score logits are normalized, while some logits processors or warpers can break that normalization.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
@@ -1241,6 +1255,7 @@ class GenerationMixin:
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
# 8. prepare stopping criteria
@@ -1271,7 +1286,12 @@
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
- top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
+ top_k=top_k,
+ top_p=top_p,
+ typical_p=typical_p,
+ temperature=temperature,
+ num_beams=num_beams,
+ renormalize_logits=renormalize_logits,
)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1333,7 +1353,12 @@
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
- top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
+ top_k=top_k,
+ top_p=top_p,
+ typical_p=typical_p,
+ temperature=temperature,
+ num_beams=num_beams,
+ renormalize_logits=renormalize_logits,
)
if stopping_criteria.max_length is None:
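End to end, the new flag is passed straight through `generate`. A hedged sketch (the checkpoint name is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# beam sample mode: warpers run on the scores, so renormalize_logits=True appends
# LogitNormalization last and keeps hypothesis scores comparable as log-probabilities
outputs = model.generate(
    **inputs,
    do_sample=True,
    num_beams=4,
    top_k=50,
    renormalize_logits=True,
    max_length=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))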

src/transformers/models/rag/modeling_rag.py

@@ -1400,6 +1400,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
n_docs: Optional[int] = None,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
@@ -1624,6 +1625,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
if num_beams == 1:

tests/test_generation_logits_process.py

@@ -33,6 +33,7 @@ if is_torch_available():
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
@@ -537,3 +538,18 @@ class LogitsProcessorTest(unittest.TestCase):
scores_after_start[penalty_start + 1 :, eos_token_id], scores[penalty_start + 1 :, eos_token_id]
).all()
)
def test_normalization(self):
input_ids = None
scores = torch.tensor(
[[-23.18, -29.96, -43.54, 47.77], [-33.58, -26.87, -32.96, 22.51]], device=torch_device, dtype=torch.float
)
logit_normalization = LogitNormalization()
normalized_scores = logit_normalization(input_ids, scores).exp()
ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float)
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))