Fix generation doctests (#30263)
* fix doctest
* fix torch doctest
* make CI happy
* raise error
* make fixup
This commit is contained in:
parent 2ecefc3959
commit b8ac4d035c
src/transformers/generation/candidate_generator.py

@@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 import torch

 from ..cache_utils import DynamicCache
+from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor


 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
     from .configuration_utils import GenerationConfig
-    from .logits_process import LogitsProcessorList


 class CandidateGenerator:
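The import moves out of the `if TYPE_CHECKING:` block because the code added further down instantiates `LogitsProcessorList` and compares against `MinLengthLogitsProcessor` at runtime, and names imported under `TYPE_CHECKING` exist only for static type checkers. A minimal generic sketch of the distinction, using `collections.OrderedDict` as a stand-in (nothing transformers-specific):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections import OrderedDict  # visible to type checkers only

def annotated(d: "OrderedDict") -> int:  # string annotation: fine at runtime
    return len(d)

try:
    OrderedDict()  # never imported at runtime -> NameError
except NameError as err:
    print(err)  # name 'OrderedDict' is not defined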
@@ -94,9 +94,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
         generation_config: "GenerationConfig",
-        logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
+        logits_processor: "LogitsProcessorList" = None,
     ):
         # Make sure all data at the same device as assistant model
         device = assistant_model.device
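`logits_processor` becomes an optional trailing parameter. The default is `None` rather than `LogitsProcessorList()` because Python evaluates default values once, at definition time, so a mutable default object would be shared across all calls; the `__init__` body (next hunk) swaps in a fresh list instead. A generic sketch of that pattern (function names here are illustrative, not from the codebase):

def shared_default(items=[]):  # the default list is created once and reused
    items.append(1)
    return items

def fresh_default(items=None):  # the pattern used in the hunk below
    items = items if items is not None else []
    items.append(1)
    return items

print(shared_default())  # [1]
print(shared_default())  # [1, 1]  <- state leaked across calls
print(fresh_default())   # [1]
print(fresh_default())   # [1]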
@@ -145,15 +145,22 @@ class AssistedCandidateGenerator(CandidateGenerator):
         self.input_ids_key = "input_ids"

         # Prepare generation-related options.
-        self.logits_processor = logits_processor
+        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         self.generation_config = copy.deepcopy(generation_config)
         self.generation_config.return_dict_in_generate = True
         self.generation_config.output_scores = True

         # avoid unnecessary warnings that min_length is larger than max_new_tokens
+        # remove the `MinLengthLogitsProcessor` if it exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
         self.main_model_min_length = self.generation_config.min_length
         self.generation_config.min_length = 0
         self.generation_config.min_new_tokens = None
+        for processor in self.logits_processor:
+            if type(processor) == MinLengthLogitsProcessor:
+                raise ValueError(
+                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation` is disabled. "
+                    "Please pass in `min_length` into `.generate()` instead"
+                )

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
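The guard uses an exact `type(...) ==` comparison, so only `MinLengthLogitsProcessor` itself is rejected, not subclasses. A standalone sketch of the failure mode it introduces, with local stub classes in place of the transformers ones:

# Local stubs standing in for the transformers classes; this snippet is
# self-contained and only mirrors the guard's control flow.
class LogitsProcessor:
    pass

class MinLengthLogitsProcessor(LogitsProcessor):
    def __init__(self, min_length, eos_token_id):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

def validate_assisted_processors(logits_processor):
    # Mirrors the loop above: reject the processor that assisted generation
    # now manages through `generation_config.min_length` instead.
    for processor in logits_processor:
        if type(processor) == MinLengthLogitsProcessor:
            raise ValueError(
                "Passing `MinLengthLogitsProcessor` when using `assisted_generation` is disabled. "
                "Please pass in `min_length` into `.generate()` instead"
            )

validate_assisted_processors([])  # empty list: accepted
try:
    validate_assisted_processors([MinLengthLogitsProcessor(10, eos_token_id=0)])
except ValueError as err:
    print(err)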
src/transformers/generation/tf_utils.py

@@ -528,9 +528,9 @@ class TFGenerationMixin:
         >>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
         ...     # | token | token string | logits | probability
         ...     print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
-        |   262 |  the     | -1.413 | 24.33%
+        |   262 |  the     | -1.414 | 24.33%
         |  1110 |  day     | -2.609 | 7.36%
-        |   618 |  when    | -2.009 | 13.41%
+        |   618 |  when    | -2.010 | 13.40%
         |   356 |  we      | -1.859 | 15.58%
         |   460 |  can     | -2.508 | 8.14%
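Both columns of each row are printed from the same underlying tensor value, so the corrected rows restore the rounding consistency between the log-score and probability columns. A quick numpy check (the score here is back-solved to fit the corrected row for token 262, not taken from the model):

import numpy as np

score = -1.4136  # illustrative value consistent with the corrected doctest row
print(f"| {score:.3f} | {np.exp(score):.2%}")  # | -1.414 | 24.33%

# The old row's -1.413 cannot pair with 24.33%:
print(f"{np.exp(-1.413):.2%}")  # 24.34%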
@@ -549,7 +549,7 @@ class TFGenerationMixin:
         >>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
         >>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
         >>> # use case, you might want to recompute it with `normalize_logits=True`.
-        >>> output_length = input_length + np.sum(transition_scores.numpy() < 0, axis=1)
+        >>> output_length = np.sum(transition_scores.numpy() < 0, axis=1)
         >>> length_penalty = model.generation_config.length_penalty
         >>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
         >>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
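In this docstring, `transition_scores` holds one log-probability per generated token, with zeros in padded positions, so counting entries `< 0` recovers each sequence's generated length, and the penalized sum reconstructs `sequences_scores`. A toy numpy illustration with made-up scores (not model output):

import numpy as np

# Made-up transition scores for two finished beams, three generated tokens each.
transition_scores = np.array(
    [[-1.4, -2.6, -2.0],
     [-1.9, -2.5, -1.1]]
)
length_penalty = 1.0  # default in GenerationConfig

# One negative log-probability per generated token; padded positions are 0.
output_length = np.sum(transition_scores < 0, axis=1)
reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
print(output_length)         # [3 3]
print(reconstructed_scores)  # [-2.         -1.83333333]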
src/transformers/generation/utils.py

@@ -705,9 +705,9 @@ class GenerationMixin:
             input_ids=input_ids,
             assistant_model=assistant_model,
             generation_config=generation_config,
-            logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
+            logits_processor=logits_processor,
         )
         return candidate_generator
@@ -4601,24 +4601,18 @@ class GenerationMixin:
         >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
         >>> input_prompt = "It might be possible to"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-        >>> # instantiate logits processors
-        >>> logits_processor = LogitsProcessorList(
-        ...     [
-        ...         MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
-        ...     ]
-        ... )
+        >>> model.generation_config.min_length = 10

         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
         >>> candidate_generator = AssistedCandidateGenerator(
         ...     input_ids=input_ids,
         ...     assistant_model=assistant_model,
         ...     generation_config=model.generation_config,
-        ...     logits_processor=logits_processor,
         ...     model_kwargs={},
         ... )
         >>> outputs = model._assisted_decoding(
         ...     input_ids,
         ...     candidate_generator=candidate_generator,
-        ...     logits_processor=logits_processor,
         ...     stopping_criteria=stopping_criteria,
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
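The doctest drives the private `_assisted_decoding` entry point directly in order to exercise `AssistedCandidateGenerator`; in user code the same path is normally reached through `generate()` with an `assistant_model`. A sketch of that public route — the checkpoint names are assumptions, since the doctest's model-loading lines are not part of this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")  # shares gpt2's tokenizer

inputs = tokenizer("It might be possible to", return_tensors="pt")
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,  # routes generation through assisted decoding
    min_length=10,                    # replaces the removed MinLengthLogitsProcessor
    max_length=20,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))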