Docs: Update logit processors __call__ docs (#24729)

* tmp commit

* __call__ docs

* kwargs documented; shorter input_ids doc

* nit

* Update src/transformers/generation/logits_process.py
This commit is contained in:
Joao Gante 2023-07-12 12:21:30 +01:00 committed by GitHub
parent 6e2f069650
commit 430a04a75a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 65 additions and 19 deletions

View File

@ -30,17 +30,10 @@ logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search
kwargs (`Dict[str, Any]`, *optional*):
Additional logits processor specific kwargs.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
@ -53,7 +46,6 @@ class LogitsProcessor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@ -64,7 +56,6 @@ class LogitsWarper:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
@ -77,8 +68,22 @@ class LogitsProcessorList(list):
[`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
kwargs (`Dict[str, Any]`, *optional*):
Additional kwargs that are specific to a logits processor.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
@ -116,6 +121,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
self.min_length = min_length
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
@ -154,6 +160,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
self.min_new_tokens = min_new_tokens
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
if new_tokens_length < self.min_new_tokens:
@ -178,7 +185,8 @@ class TemperatureLogitsWarper(LogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores / self.temperature
return scores
@ -199,6 +207,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = penalty
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, input_ids)
@ -227,6 +236,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
self.penalty = 1 / penalty
self.encoder_input_ids = encoder_input_ids
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
score = torch.gather(scores, 1, self.encoder_input_ids)
@ -262,6 +272,7 @@ class TopPLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
@ -297,6 +308,7 @@ class TopKLogitsWarper(LogitsWarper):
self.top_k = max(top_k, min_tokens_to_keep)
self.filter_value = filter_value
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
top_k = min(self.top_k, scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
@ -330,6 +342,7 @@ class TypicalLogitsWarper(LogitsWarper):
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
@ -383,6 +396,7 @@ class EpsilonLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Determine which indices to remove
probabilities = scores.softmax(dim=-1)
@ -422,6 +436,7 @@ class EtaLogitsWarper(LogitsWarper):
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Calculate the adaptive cutoff
probabilities = scores.softmax(dim=-1)
@ -487,6 +502,7 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
@ -521,6 +537,7 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
self.batch_size = encoder_input_ids.shape[0]
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# B x num_beams
num_hypos = scores.shape[0]
@ -612,6 +629,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
self.length_greather_than_1_bias = None
self.prepared_bias_variables = False
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
if not self.prepared_bias_variables:
@ -774,6 +792,7 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
@ -821,6 +840,23 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
current_tokens: torch.LongTensor,
beam_group_idx: int,
) -> torch.FloatTensor:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token when using beam search
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
beam groups in the current generation step.
beam_group_idx (`int`):
The index of the beam group currently being processed.
Return:
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
The processed prediction scores.
"""
# hamming diversity: penalise using same token in current group which was used in previous groups at
# the same time step
batch_size = current_tokens.shape[0] // self._num_beams
@ -855,6 +891,7 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == 1:
@ -882,6 +919,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
@ -898,6 +936,7 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
the logits processor should only be used if necessary since it can slow down the generation method.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# set all nan values to 0.0
scores[scores != scores] = 0.0
@ -935,7 +974,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
eos_token_id = [eos_token_id]
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
for i in self.eos_token_id:
@ -951,7 +991,8 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
the scores are normalized when comparing the hypotheses.
"""
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = scores.log_softmax(dim=-1)
return scores
@ -967,7 +1008,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
scores[:, self.begin_suppress_tokens] = -float("inf")
@ -981,7 +1023,8 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
def __init__(self, suppress_tokens):
self.suppress_tokens = list(suppress_tokens)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores[:, self.suppress_tokens] = -float("inf")
return scores
@ -994,7 +1037,8 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
def __init__(self, force_token_map: List[List[int]]):
self.force_token_map = dict(force_token_map)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
generation_idx = input_ids.shape[-1]
current_token = self.force_token_map.get(generation_idx, None)
if current_token is not None:
@ -1030,7 +1074,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
@ -1089,7 +1134,8 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
f"{guidance_scale}."
)
def __call__(self, input_ids, scores):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# simple check to make sure we have compatible batch sizes between our
# logits scores (cond + uncond) and input ids (cond only)
if scores.shape[0] != 2 * input_ids.shape[0]: