Docs: Update logit processors __call__ docs (#24729)
* tmp commit
* __call__ docs
* kwargs documented; shorter input_ids doc
* nit
* Update src/transformers/generation/logits_process.py
parent 6e2f069650
commit 430a04a75a
src/transformers/generation/logits_process.py

@@ -30,17 +30,10 @@ logger = get_logger(__name__)
 LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
+            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
         scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
             Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
             search or log softmax for each vocabulary token when using beam search
-        kwargs (`Dict[str, Any]`, *optional*):
-            Additional logits processor specific kwargs.
 
     Return:
         `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
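For context, the shared docstring above describes the `__call__` contract every processor implements: take `input_ids` and `scores`, return scores of the same shape. A minimal sketch of a custom processor written against that contract (the class name and the blocked-token rule are invented for illustration, not part of this commit):

    import torch
    from transformers import LogitsProcessor

    class BlockTokenLogitsProcessor(LogitsProcessor):
        """Hypothetical processor that bans a single token id at every generation step."""

        def __init__(self, blocked_token_id: int):
            self.blocked_token_id = blocked_token_id

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # `scores` is (batch_size, vocab_size); return the same shape with the token masked out
            scores[:, self.blocked_token_id] = -float("inf")
            return scores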
@@ -53,7 +46,6 @@ class LogitsProcessor:
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        """Torch method for processing logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
         )
@@ -64,7 +56,6 @@ class LogitsWarper:
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        """Torch method for warping logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
         )
@@ -77,8 +68,22 @@ class LogitsProcessorList(list):
     [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
     """
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
+                beam search or log softmax for each vocabulary token when using beam search
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional kwargs that are specific to a logits processor.
+
+        Return:
+            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
+                The processed prediction scores.
+
+        """
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
             if len(function_args) > 2:
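The explicit docstring matters here because, unlike individual processors, `LogitsProcessorList.__call__` accepts `**kwargs` and forwards them only to processors whose `__call__` takes more than the two standard arguments (that is what the `inspect.signature` check above is for). A short usage sketch with arbitrary values:

    import torch
    from transformers import LogitsProcessorList, MinLengthLogitsProcessor, TemperatureLogitsWarper

    processors = LogitsProcessorList(
        [
            MinLengthLogitsProcessor(min_length=5, eos_token_id=2),
            TemperatureLogitsWarper(temperature=0.7),
        ]
    )

    input_ids = torch.tensor([[0, 4, 7]])      # (batch_size, sequence_length)
    scores = torch.randn(1, 32)                # (batch_size, vocab_size)
    processed = processors(input_ids, scores)  # same shape as `scores`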
@@ -116,6 +121,7 @@ class MinLengthLogitsProcessor(LogitsProcessor):
         self.min_length = min_length
         self.eos_token_id = eos_token_id
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len < self.min_length:
@@ -154,6 +160,7 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
         self.min_new_tokens = min_new_tokens
         self.eos_token_id = eos_token_id
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
         if new_tokens_length < self.min_new_tokens:
@@ -178,7 +185,8 @@ class TemperatureLogitsWarper(LogitsWarper):
 
         self.temperature = temperature
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores = scores / self.temperature
         return scores
 
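The warper itself only divides the logits by the temperature; the effect is to sharpen (temperature < 1) or flatten (temperature > 1) the softmax distribution. A quick numeric illustration, independent of the diff:

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    print(torch.softmax(logits, dim=-1))        # ~[0.63, 0.23, 0.14]
    print(torch.softmax(logits / 0.5, dim=-1))  # sharper: ~[0.84, 0.11, 0.04]
    print(torch.softmax(logits / 2.0, dim=-1))  # flatter: ~[0.48, 0.29, 0.23]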
@@ -199,6 +207,7 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
 
         self.penalty = penalty
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, input_ids)
 
@@ -227,6 +236,7 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         self.penalty = 1 / penalty
         self.encoder_input_ids = encoder_input_ids
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         score = torch.gather(scores, 1, self.encoder_input_ids)
 
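Only the first line of the repetition-penalty body is visible in the hunk; the usual next step is to scale the gathered logits and scatter them back. A standalone sketch of that pattern (not copied from the file):

    import torch

    def apply_repetition_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty: float) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)  # logits of tokens that already appeared
        # A penalty > 1 pushes positive logits down and negative logits further down
        score = torch.where(score < 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)
        return scores

    scores = apply_repetition_penalty(torch.tensor([[3, 7]]), torch.randn(1, 20), penalty=1.2)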
@@ -262,6 +272,7 @@ class TopPLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         sorted_logits, sorted_indices = torch.sort(scores, descending=False)
         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
@@ -297,6 +308,7 @@ class TopKLogitsWarper(LogitsWarper):
         self.top_k = max(top_k, min_tokens_to_keep)
         self.filter_value = filter_value
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         top_k = min(self.top_k, scores.size(-1))  # Safety check
         # Remove all tokens with a probability less than the last token of the top-k
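Here too the decorator is the only change; the top-k filtering that follows the safety check keeps the k highest-scoring tokens per row and masks the rest. A self-contained sketch of that technique (not the file's exact code):

    import torch

    def top_k_filter(scores: torch.FloatTensor, top_k: int, filter_value: float = -float("inf")) -> torch.FloatTensor:
        top_k = min(top_k, scores.size(-1))  # safety check, as in the hunk above
        # Mask every token whose logit is below the k-th largest logit in its row
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        return scores.masked_fill(indices_to_remove, filter_value)

    filtered = top_k_filter(torch.randn(2, 50), top_k=10)  # 40 entries per row become -inf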
@@ -330,6 +342,7 @@ class TypicalLogitsWarper(LogitsWarper):
         self.mass = mass
         self.min_tokens_to_keep = min_tokens_to_keep
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # calculate entropy
         normalized = torch.nn.functional.log_softmax(scores, dim=-1)
@@ -383,6 +396,7 @@ class EpsilonLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Determine which indices to remove
         probabilities = scores.softmax(dim=-1)
@@ -422,6 +436,7 @@ class EtaLogitsWarper(LogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # Calculate the adaptive cutoff
         probabilities = scores.softmax(dim=-1)
@@ -487,6 +502,7 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
         self.ngram_size = ngram_size
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         num_batch_hypotheses = scores.shape[0]
         cur_len = input_ids.shape[-1]
@@ -521,6 +537,7 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
         self.batch_size = encoder_input_ids.shape[0]
         self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # B x num_beams
         num_hypos = scores.shape[0]
@@ -612,6 +629,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
         self.length_greather_than_1_bias = None
         self.prepared_bias_variables = False
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
         if not self.prepared_bias_variables:
@@ -774,6 +792,7 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
         self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
         self._num_beams = num_beams
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         mask = torch.full_like(scores, -math.inf)
         for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
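PrefixConstrainedLogitsProcessor is normally built for you when `prefix_allowed_tokens_fn` is passed to `generate`. A usage sketch (the model and the toy constraint are placeholders, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    def allowed_tokens(batch_id, input_ids):
        # Toy constraint: only the first 1000 token ids are allowed at every step
        return list(range(1000))

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    out = model.generate(**inputs, prefix_allowed_tokens_fn=allowed_tokens, num_beams=2, max_new_tokens=5)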
@@ -821,6 +840,23 @@ class HammingDiversityLogitsProcessor(LogitsProcessor):
         current_tokens: torch.LongTensor,
         beam_group_idx: int,
     ) -> torch.FloatTensor:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
+                beam search or log softmax for each vocabulary token when using beam search
+            current_tokens (`torch.LongTensor` of shape `(batch_size)`):
+                Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
+                beam groups in the current generation step.
+            beam_group_idx (`int`):
+                The index of the beam group currently being processed.
+
+        Return:
+            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
+                The processed prediction scores.
+        """
         # hamming diversity: penalise using same token in current group which was used in previous groups at
         # the same time step
         batch_size = current_tokens.shape[0] // self._num_beams
@@ -855,6 +891,7 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
     def __init__(self, bos_token_id: int):
         self.bos_token_id = bos_token_id
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len == 1:
@@ -882,6 +919,7 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
             eos_token_id = [eos_token_id]
         self.eos_token_id = eos_token_id
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len == self.max_length - 1:
@@ -898,6 +936,7 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor):
     the logits processor should only be used if necessary since it can slow down the generation method.
     """
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # set all nan values to 0.0
         scores[scores != scores] = 0.0
@@ -935,7 +974,8 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
             eos_token_id = [eos_token_id]
         self.eos_token_id = eos_token_id
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
         if cur_len > self.regulation_start:
             for i in self.eos_token_id:
@@ -951,7 +991,8 @@ class LogitNormalization(LogitsProcessor, LogitsWarper):
     the scores are normalized when comparing the hypotheses.
     """
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores = scores.log_softmax(dim=-1)
         return scores
 
@@ -967,7 +1008,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
         self.begin_suppress_tokens = list(begin_suppress_tokens)
         self.begin_index = begin_index
 
-    def __call__(self, input_ids, scores):
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if input_ids.shape[1] == self.begin_index:
             scores[:, self.begin_suppress_tokens] = -float("inf")
 
@@ -981,7 +1023,8 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     def __init__(self, suppress_tokens):
         self.suppress_tokens = list(suppress_tokens)
 
-    def __call__(self, input_ids, scores):
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores[:, self.suppress_tokens] = -float("inf")
         return scores
 
@@ -994,7 +1037,8 @@ class ForceTokensLogitsProcessor(LogitsProcessor):
     def __init__(self, force_token_map: List[List[int]]):
         self.force_token_map = dict(force_token_map)
 
-    def __call__(self, input_ids, scores):
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         generation_idx = input_ids.shape[-1]
         current_token = self.force_token_map.get(generation_idx, None)
         if current_token is not None:
@@ -1030,7 +1074,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
             self.begin_index -= 1
         self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
 
-    def __call__(self, input_ids, scores):
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")
 
@@ -1089,7 +1134,8 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
                 f"{guidance_scale}."
             )
 
-    def __call__(self, input_ids, scores):
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # simple check to make sure we have compatible batch sizes between our
         # logits scores (cond + uncond) and input ids (cond only)
         if scores.shape[0] != 2 * input_ids.shape[0]:
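The classifier-free-guidance hunk only shows the batch-size check (`scores` stacks conditional and unconditional logits). The guidance step itself combines the two halves; in sketch form, rather than the file's exact code:

    import torch

    def cfg_combine(scores: torch.FloatTensor, guidance_scale: float) -> torch.FloatTensor:
        # `scores` is (2 * batch_size, vocab_size); assumed order: conditional half first, then unconditional
        cond, uncond = scores.chunk(2, dim=0)
        # Move the conditional logits further away from the unconditional ones
        return uncond + guidance_scale * (cond - uncond)

    guided = cfg_combine(torch.randn(4, 32), guidance_scale=3.0)  # -> shape (2, 32)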