Fix generation doctests (#30263)

* fix doctest

* fix torch doctest

* make CI happy

* raise error

* make fixup
Raushan Turganbay 2024-05-01 00:02:26 +05:00 committed by GitHub
parent 2ecefc3959
commit b8ac4d035c
3 changed files with 16 additions and 15 deletions

src/transformers/generation/candidate_generator.py

@@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 import torch

 from ..cache_utils import DynamicCache
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor

 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
     from .configuration_utils import GenerationConfig
-    from .logits_process import LogitsProcessorList

 class CandidateGenerator:
@@ -94,9 +94,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
         generation_config: "GenerationConfig",
-        logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: "LogitsProcessorList" = None,
     ):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
@@ -145,15 +145,22 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"

         # Prepare generation-related options.
-        self.logits_processor = logits_processor
+        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True

         # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
         self.generation_config.min_length = 0
         self.generation_config.min_new_tokens = None
+        for processor in self.logits_processor:
+            if type(processor) == MinLengthLogitsProcessor:
+                raise ValueError(
+                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation` is disabled. "
+                    "Please pass `min_length` to `.generate()` instead"
+                )

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """

src/transformers/generation/tf_utils.py

@@ -528,9 +528,9 @@ class TFGenerationMixin:
 >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
 ...     # | token | token string | logits | probability
 ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
-| 262 | the | -1.413 | 24.33%
+| 262 | the | -1.414 | 24.33%
 | 1110 | day | -2.609 | 7.36%
-| 618 | when | -2.009 | 13.41%
+| 618 | when | -2.010 | 13.40%
 | 356 | we | -1.859 | 15.58%
 | 460 | can | -2.508 | 8.14%
@@ -549,7 +549,7 @@ class TFGenerationMixin:
 >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
 >>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
 >>> # use case, you might want to recompute it with `normalize_logits=True`.
->>> output_length = input_length + np.sum(transition_scores.numpy() < 0, axis=1)
+>>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
 >>> length_penalty = model.generation_config.length_penalty
 >>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
 >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
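
The fixed line no longer adds `input_length`: `transition_scores` carries one strictly negative log-probability per generated token, so the `< 0` mask alone counts the generated length. A toy check of the reconstruction arithmetic, with made-up scores (not the doctest's real values):

```python
import numpy as np

# One row per sequence, one (negative) log-prob per generated token.
transition_scores = np.array([[-1.414, -2.609, -2.010]])

output_length = np.sum(transition_scores < 0, axis=1)  # array([3])
length_penalty = 1.0                                   # GenerationConfig default

# Sum of log-probs, normalized by (generated length ** length_penalty):
reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
print(reconstructed_scores)  # [-2.011]
```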

src/transformers/generation/utils.py

@@ -705,9 +705,9 @@ class GenerationMixin:
             input_ids=input_ids,
             assistant_model=assistant_model,
             generation_config=generation_config,
-            logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
+            logits_processor=logits_processor,
         )
         return candidate_generator
@@ -4601,24 +4601,18 @@ class GenerationMixin:
 >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
 >>> input_prompt = "It might be possible to"
 >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
->>> # instantiate logits processors
->>> logits_processor = LogitsProcessorList(
-...     [
-...         MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
-...     ]
-... )
+>>> model.generation_config.min_length = 10
 >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 >>> candidate_generator = AssistedCandidateGenerator(
 ...     input_ids=input_ids,
 ...     assistant_model=assistant_model,
 ...     generation_config=model.generation_config,
-...     logits_processor=logits_processor,
 ...     model_kwargs={},
 ... )
 >>> outputs = model._assisted_decoding(
 ...     input_ids,
 ...     candidate_generator=candidate_generator,
-...     logits_processor=logits_processor,
+...     logits_processor=LogitsProcessorList(),
 ...     stopping_criteria=stopping_criteria,
 ... )
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
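
With this change, the minimum generation length for assisted decoding flows through the generation config rather than a `MinLengthLogitsProcessor`. A hedged sketch of the equivalent public-API call; the checkpoint names are illustrative, and `generate()` routes `assistant_model` through the candidate-generator path patched above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
model.generation_config.pad_token_id = model.generation_config.eos_token_id

inputs = tokenizer("It might be possible to", return_tensors="pt")

# min_length is passed to generate() (i.e. into the GenerationConfig);
# wrapping it in a MinLengthLogitsProcessor now raises a ValueError.
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    min_length=10,
    max_length=20,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```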