Generate: move generation_*.py src files into generation/*.py (#20096)
* move generation_*.py src files into generation/*.py
* populate generation.__init__ with lazy loading
* move imports and references from generation_xxx.object to generation.object
This commit is contained in:
parent bac2d29a80
commit f270b960d6
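In practice, the user-facing change is the import path: objects that previously lived in flat `generation_*` modules are now re-exported from a `generation` subpackage. A minimal sketch of the before/after, based on the `from transformers.generation_utils import GenerationMixin` hunk further down in this diff:

```python
# Before this commit: flat module at the top of the package
from transformers.generation_utils import GenerationMixin

# After this commit: the new generation/ subpackage with a lazy __init__
from transformers.generation import GenerationMixin
```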
@@ -56,7 +56,7 @@ If you have more than one input, pass the input as a list:
 ... ) # doctest: +SKIP
 ```
 
-Any additional parameters for your task can also be included in the [`pipeline`]. The `Text-Generierung` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `Text-Generierung` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
 
 ```py
 >>> generator(

@@ -12,22 +12,22 @@ specific language governing permissions and limitations under the License.
 
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation_utils.GenerationMixin.generate`],
-[`~generation_utils.GenerationMixin.greedy_search`],
-[`~generation_utils.GenerationMixin.contrastive_search`],
-[`~generation_utils.GenerationMixin.sample`],
-[`~generation_utils.GenerationMixin.beam_search`],
-[`~generation_utils.GenerationMixin.beam_sample`],
-[`~generation_utils.GenerationMixin.group_beam_search`], and
-[`~generation_utils.GenerationMixin.constrained_beam_search`].
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
+[`~generation.GenerationMixin.greedy_search`],
+[`~generation.GenerationMixin.contrastive_search`],
+[`~generation.GenerationMixin.sample`],
+[`~generation.GenerationMixin.beam_search`],
+[`~generation.GenerationMixin.beam_sample`],
+[`~generation.GenerationMixin.group_beam_search`], and
+[`~generation.GenerationMixin.constrained_beam_search`].
 
 Most of those are only useful if you are studying the code of the generate methods in the library.
 
 ## Generate Outputs
 
-The output of [`~generation_utils.GenerationMixin.generate`] is an instance of a subclass of
+The output of [`~generation.GenerationMixin.generate`] is an instance of a subclass of
 [`~utils.ModelOutput`]. This output is a data structure containing all the information returned
-by [`~generation_utils.GenerationMixin.generate`], but that can also be used as tuple or dictionary.
+by [`~generation.GenerationMixin.generate`], but that can also be used as tuple or dictionary.
 
 Here's an example:
 
@@ -41,7 +41,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
 generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 ```
 
-The `generation_output` object is a [`~generation_utils.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
 see in the documentation of that class below, it means it has the following attributes:
 
 - `sequences`: the generated sequences of tokens

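Since the doc text above stresses that this output works both as a typed object and as a tuple, here is a brief hedged sketch of what that means in practice, reusing the `generation_output` computed in the snippet above:

```python
# Named attribute access, like a dataclass
sequences = generation_output.sequences
scores = generation_output.scores  # populated because output_scores=True was passed

# Positional access also works, like a tuple: index 0 is the first non-None field
first_field = generation_output[0]  # same tensor as generation_output.sequences
```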
@@ -73,31 +73,31 @@ We document here all output types.
 
 ### GreedySearchOutput
 
-[[autodoc]] generation_utils.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GreedySearchDecoderOnlyOutput
 
-[[autodoc]] generation_utils.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.GreedySearchEncoderDecoderOutput
 
-[[autodoc]] generation_flax_utils.FlaxGreedySearchOutput
+[[autodoc]] generation.FlaxGreedySearchOutput
 
 ### SampleOutput
 
-[[autodoc]] generation_utils.SampleDecoderOnlyOutput
+[[autodoc]] generation.SampleDecoderOnlyOutput
 
-[[autodoc]] generation_utils.SampleEncoderDecoderOutput
+[[autodoc]] generation.SampleEncoderDecoderOutput
 
-[[autodoc]] generation_flax_utils.FlaxSampleOutput
+[[autodoc]] generation.FlaxSampleOutput
 
 ### BeamSearchOutput
 
-[[autodoc]] generation_utils.BeamSearchDecoderOnlyOutput
+[[autodoc]] generation.BeamSearchDecoderOnlyOutput
 
-[[autodoc]] generation_utils.BeamSearchEncoderDecoderOutput
+[[autodoc]] generation.BeamSearchEncoderDecoderOutput
 
 ### BeamSampleOutput
 
-[[autodoc]] generation_utils.BeamSampleDecoderOnlyOutput
+[[autodoc]] generation.BeamSampleDecoderOnlyOutput
 
-[[autodoc]] generation_utils.BeamSampleEncoderDecoderOutput
+[[autodoc]] generation.BeamSampleEncoderDecoderOutput
 
 ## LogitsProcessor
 
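The `## LogitsProcessor` section itself is truncated by the hunk boundary. For orientation, a hedged sketch of how such processors are typically composed and passed to `generate` (checkpoint and parameter values are illustrative only):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import LogitsProcessorList, MinLengthLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")

# Suppress the EOS token until at least 20 tokens have been produced
processors = LogitsProcessorList([MinLengthLogitsProcessor(20, eos_token_id=model.config.eos_token_id)])
output_ids = model.generate(**inputs, logits_processor=processors)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```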
@@ -25,9 +25,9 @@ are common among all the models to:
 
 The other methods that are common to each model are defined in [`~modeling_utils.ModuleUtilsMixin`]
 (for the PyTorch models) and [`~modeling_tf_utils.TFModuleUtilsMixin`] (for the TensorFlow models) or
-for text generation, [`~generation_utils.GenerationMixin`] (for the PyTorch models),
-[`~generation_tf_utils.TFGenerationMixin`] (for the TensorFlow models) and
-[`~generation_flax_utils.FlaxGenerationMixin`] (for the Flax/JAX models).
+for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
+[`~generation.TFGenerationMixin`] (for the TensorFlow models) and
+[`~generation.FlaxGenerationMixin`] (for the Flax/JAX models).
 
 
 ## PreTrainedModel

@@ -14,13 +14,13 @@ specific language governing permissions and limitations under the License.
 
 Each framework has a generate method for auto-regressive text generation implemented in their respective `GenerationMixin` class:
 
-- PyTorch [`~generation_utils.GenerationMixin.generate`] is implemented in [`~generation_utils.GenerationMixin`].
-- TensorFlow [`~generation_tf_utils.TFGenerationMixin.generate`] is implemented in [`~generation_tf_utils.TFGenerationMixin`].
-- Flax/JAX [`~generation_flax_utils.FlaxGenerationMixin.generate`] is implemented in [`~generation_flax_utils.FlaxGenerationMixin`].
+- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
+- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
+- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
 
 ## GenerationMixin
 
-[[autodoc]] generation_utils.GenerationMixin
+[[autodoc]] generation.GenerationMixin
     - generate
     - greedy_search
     - sample

@@ -32,10 +32,10 @@ Each framework has a generate method for auto-regressive text generation impleme
 
 ## TFGenerationMixin
 
-[[autodoc]] generation_tf_utils.TFGenerationMixin
+[[autodoc]] generation.TFGenerationMixin
     - generate
 
 ## FlaxGenerationMixin
 
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
+[[autodoc]] generation.FlaxGenerationMixin
     - generate

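For orientation while reading this doc-page diff, a hedged sketch of the PyTorch `generate` entry point the page documents (checkpoint name and arguments are illustrative only):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")  # inherits the GenerationMixin documented above

inputs = tokenizer("Today is a beautiful day, and", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)  # greedy decoding
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```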
@@ -58,7 +58,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
 - Model predictions are intended to be identical to the original implementation when
 `forced_bos_token_id=0`. This only works, however, if the string you pass to
 [`fairseq.encode`] starts with a space.
-- [`~generation_utils.GenerationMixin.generate`] should be used for conditional generation tasks like
+- [`~generation.GenerationMixin.generate`] should be used for conditional generation tasks like
 summarization, see the example in that docstrings.
 - Models that load the *facebook/bart-large-cnn* weights will not have a `mask_token_id`, or be able to perform
 mask-filling tasks.

@@ -188,4 +188,4 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 ## FlaxBartForCausalLM
 
 [[autodoc]] FlaxBartForCausalLM
-    - __call__
+    - __call__

@@ -23,7 +23,7 @@ The abstract from the paper is the following:
 *Understanding document images (e.g., invoices) is a core but challenging task since it requires complex functions such as reading text and a holistic understanding of the document. Current Visual Document Understanding (VDU) methods outsource the task of reading text to off-the-shelf Optical Character Recognition (OCR) engines and focus on the understanding task with the OCR outputs. Although such OCR-based approaches have shown promising performance, they suffer from 1) high computational costs for using OCR; 2) inflexibility of OCR models on languages or types of document; 3) OCR error propagation to the subsequent process. To address these issues, in this paper, we introduce a novel OCR-free VDU model named Donut, which stands for Document understanding transformer. As the first step in OCR-free VDU research, we propose a simple architecture (i.e., Transformer) with a pre-training objective (i.e., cross-entropy loss). Donut is conceptually simple yet effective. Through extensive experiments and analyses, we show a simple OCR-free VDU model, Donut, achieves state-of-the-art performances on various VDU tasks in terms of both speed and accuracy. In addition, we offer a synthetic data generator that helps the model pre-training to be flexible in various languages and domains.*
 
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/donut_architecture.jpg"
-alt="drawing" width="600"/>
+alt="drawing" width="600"/>
 
 <small> Donut high-level overview. Taken from the <a href="https://arxiv.org/abs/2111.15664">original paper</a>. </small>
 

@@ -40,7 +40,7 @@ Tips:
 ## Inference
 
 Donut's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
 
 The [`DonutFeatureExtractor`] class is responsible for preprocessing the input image and
 [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`] decodes the generated target tokens to the target string. The

@@ -211,4 +211,4 @@ We refer to the [tutorial notebooks](https://github.com/NielsRogge/Transformers-
 ## DonutSwinModel
 
 [[autodoc]] DonutSwinModel
-    - forward
+    - forward

@@ -53,7 +53,7 @@ Tips:
 
 ### Generation
 
-The [`~generation_utils.GenerationMixin.generate`] method can be used to generate text using GPT-J
+The [`~generation.GenerationMixin.generate`] method can be used to generate text using GPT-J
 model.
 
 ```python

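The code block in the hunk above is cut off by the hunk boundary. As a hedged sketch of the kind of GPT-J generation call the section describes (checkpoint name and sampling settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-J is large; in practice float16 weights and a GPU are commonly used
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

prompt = "In a shocking finding, scientists discovered a herd of unicorns"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
gen_ids = model.generate(input_ids, do_sample=True, temperature=0.9, max_new_tokens=50)
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```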
@@ -38,7 +38,7 @@ Tips:
 ## Inference
 
 Speech2Text2's [`SpeechEncoderDecoderModel`] model accepts raw waveform input values from speech and
-makes use of [`~generation_utils.GenerationMixin.generate`] to translate the input speech
+makes use of [`~generation.GenerationMixin.generate`] to translate the input speech
 autoregressively to the target language.
 
 The [`Wav2Vec2FeatureExtractor`] class is responsible for preprocessing the input speech and

@@ -225,7 +225,7 @@ batch) leads to very slow training on TPU.
 
 ## Inference
 
-At inference time, it is recommended to use [`~generation_utils.GenerationMixin.generate`]. This
+At inference time, it is recommended to use [`~generation.GenerationMixin.generate`]. This
 method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
 and auto-regressively generates the decoder output. Check out [this blog post](https://huggingface.co/blog/how-to-generate) to know all the details about generating text with Transformers.
 There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encoder-decoder) which explains how

@@ -244,7 +244,7 @@ Das Haus ist wunderbar.
 ```
 
 Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
-[`~generation_utils.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
+[`~generation.GenerationMixin.generate`], make sure you start it with the `pad_token_id`.
 
 The example above only shows a single example. You can also do batched inference, like so:
 
@@ -30,7 +30,7 @@ show that the TrOCR model outperforms the current state-of-the-art models on bot
 tasks.*
 
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/trocr_architecture.jpg"
-alt="drawing" width="600"/>
+alt="drawing" width="600"/>
 
 <small> TrOCR architecture. Taken from the <a href="https://arxiv.org/abs/2109.10282">original paper</a>. </small>
 
@@ -53,7 +53,7 @@ Tips:
 ## Inference
 
 TrOCR's [`VisionEncoderDecoder`] model accepts images as input and makes use of
-[`~generation_utils.GenerationMixin.generate`] to autoregressively generate text given the input image.
+[`~generation.GenerationMixin.generate`] to autoregressively generate text given the input image.
 
 The [`ViTFeatureExtractor`/`DeiTFeatureExtractor`] class is responsible for preprocessing the input image and
 [`RobertaTokenizer`/`XLMRobertaTokenizer`] decodes the generated target tokens to the target string. The

@@ -64,20 +64,20 @@ into a single instance to both extract the input features and decode the predict
 
 ``` py
 >>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel
->>> import requests
+>>> import requests
 >>> from PIL import Image
 
->>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
->>> # load image from the IAM dataset
->>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+>>> # load image from the IAM dataset
+>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
 
->>> pixel_values = processor(image, return_tensors="pt").pixel_values
+>>> pixel_values = processor(image, return_tensors="pt").pixel_values
 >>> generated_ids = model.generate(pixel_values)
 
->>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 ```
 
 See the [model hub](https://huggingface.co/models?filter=trocr) to look for TrOCR checkpoints.

@@ -24,7 +24,7 @@ The abstract from the paper is the following:
 Tips:
 
 - The model usually performs well without requiring any finetuning.
-- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
+- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation.GenerationMixin.generate`] function for inference.
 - Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
 - One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
 
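As a companion to the tips above, a hedged sketch of the `WhisperProcessor` plus `generate` round trip they describe (the checkpoint name is illustrative and the waveform is a placeholder; the real docs pair this with a dataset sample):

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder audio: one second of silence sampled at 16 kHz; substitute a real waveform
audio = np.zeros(16000, dtype=np.float32)

input_features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
```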
@@ -56,7 +56,7 @@ If you have more than one input, pass your input as a list:
 ... ) # doctest: +SKIP
 ```
 
-Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
+Any additional parameters for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set the `num_return_sequences` parameter:
 
 ```py
 >>> generator(

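The `>>> generator(` call above is cut off by the hunk boundary. A hedged sketch of the kind of call the sentence describes (prompt text and extra parameters are illustrative):

```python
from transformers import pipeline

generator = pipeline(task="text-generation")
outputs = generator(
    "In this course, we will teach you how to",
    num_return_sequences=2,  # ask the underlying generate() for two candidate continuations
    do_sample=True,
    max_new_tokens=30,
)
for candidate in outputs:
    print(candidate["generated_text"])
```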
@@ -544,7 +544,7 @@ Hugging Face is based in DUMBO, New York City, and ...
 This outputs a (hopefully) coherent next token following the original sequence, which in our case is the word *is* or
 *features*.
 
-In the next section, we show how [`generation_utils.GenerationMixin.generate`] can be used to
+In the next section, we show how [`generation.GenerationMixin.generate`] can be used to
 generate multiple tokens up to a specified length instead of one token at a time.
 
 ### Text Generation

@@ -1094,10 +1094,10 @@ The following examples demonstrate how to use a [`pipeline`] and a model and tok
 ...     images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
 ... )
 >>> print("\n".join([f"Class {d['label']} with score {round(d['score'], 4)}" for d in result]))
-Class lynx, catamount with score 0.4335
+Class lynx, catamount with score 0.4335
 Class cougar, puma, catamount, mountain lion, painter, panther, Felis concolor with score 0.0348
-Class snow leopard, ounce, Panthera uncia with score 0.0324
-Class Egyptian cat with score 0.0239
+Class snow leopard, ounce, Panthera uncia with score 0.0324
+Class Egyptian cat with score 0.0239
 Class tiger cat with score 0.0229
 ```
 
@@ -54,7 +54,7 @@ If you have more than one input, pass it as a list:
 ... )
 ```
 
-Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
+Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
 
 ```py
 >>> generator(

@@ -26,7 +26,7 @@ Dai un'occhiata alla documentazione di [`pipeline`] per una lista completa dei c
 
 ## Using the Pipeline
 
-Although each task has an associated [`pipeline`], it is simpler to use the generic [`pipeline`] abstraction, which contains all the task-specific ones. The [`pipeline`] automatically loads a default model and a tokenizer capable of performing inference for your task.
+Although each task has an associated [`pipeline`], it is simpler to use the generic [`pipeline`] abstraction, which contains all the task-specific ones. The [`pipeline`] automatically loads a default model and a tokenizer capable of performing inference for your task.
 
 1. Start by creating a [`pipeline`] and specifying the task you want to run inference on:
 
@@ -56,7 +56,7 @@ If you have more than one input, put it in a list:
 ... ) # doctest: +SKIP
 ```
 
-Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation_utils.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
+Any additional parameter for your task can be included in the [`pipeline`]. The `text-generation` task has a [`~generation.GenerationMixin.generate`] method with several parameters for controlling the output. For example, if you want to generate more than one output, use the `num_return_sequences` parameter:
 
 ```py
 >>> generator(

@@ -61,7 +61,7 @@ If you have more than one input, pass it as a list:
 ```
 
 Any additional parameter for your task can also be included in the [`pipeline`]. The `text-generation` task has a method
-[`~generation_utils.GenerationMixin.generate`] with several parameters for controlling the output.
+[`~generation.GenerationMixin.generate`] with several parameters for controlling the output.
 For example, if you want to generate more than one output, set it in the `num_return_sequences` parameter:
 
 ```py

@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 
 from transformers import BartConfig
-from transformers.generation_utils import GenerationMixin
+from transformers.generation import GenerationMixin
 
 
 def _convert_past_list_to_tuple(past_key_values):

@@ -97,6 +97,7 @@ _import_structure = {
     "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
+    "generation": [],
     "hf_argparser": ["HfArgumentParser"],
     "integrations": [
         "is_comet_available",

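The hunk above registers the new `generation` key with the library's lazy-import machinery, and the commit message's "populate generation.__init__ with lazy loading" refers to the same pattern inside the new subpackage. A simplified sketch of how that pattern works (not the actual file, which also guards torch/TF/Flax availability and lists many more symbols):

```python
# generation/__init__.py (simplified): map submodule names to their public symbols
from typing import TYPE_CHECKING

from ..utils import _LazyModule

_import_structure = {
    "utils": ["GenerationMixin", "top_k_top_p_filtering"],
    "stopping_criteria": ["MaxLengthCriteria", "StoppingCriteriaList"],
}

if TYPE_CHECKING:
    # Give type checkers and IDEs real imports
    from .stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
    from .utils import GenerationMixin, top_k_top_p_filtering
else:
    # At runtime, replace the module with a lazy proxy that only imports a
    # submodule the first time one of its attributes is accessed
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```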
@ -821,38 +822,40 @@ else:
|
|||
"TextDatasetForNextSentencePrediction",
|
||||
]
|
||||
_import_structure["deepspeed"] = []
|
||||
_import_structure["generation_beam_constraints"] = [
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"DisjunctiveConstraint",
|
||||
"PhrasalConstraint",
|
||||
]
|
||||
_import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"]
|
||||
_import_structure["generation_logits_process"] = [
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
"MinLengthLogitsProcessor",
|
||||
"NoBadWordsLogitsProcessor",
|
||||
"NoRepeatNGramLogitsProcessor",
|
||||
"PrefixConstrainedLogitsProcessor",
|
||||
"RepetitionPenaltyLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
]
|
||||
_import_structure["generation_stopping_criteria"] = [
|
||||
"MaxLengthCriteria",
|
||||
"MaxTimeCriteria",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
]
|
||||
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
|
||||
_import_structure["generation_utils"] = []
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"DisjunctiveConstraint",
|
||||
"PhrasalConstraint",
|
||||
"BeamScorer",
|
||||
"BeamSearchScorer",
|
||||
"ConstrainedBeamSearchScorer",
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
"MinLengthLogitsProcessor",
|
||||
"NoBadWordsLogitsProcessor",
|
||||
"NoRepeatNGramLogitsProcessor",
|
||||
"PrefixConstrainedLogitsProcessor",
|
||||
"RepetitionPenaltyLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"MaxLengthCriteria",
|
||||
"MaxTimeCriteria",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
"GenerationMixin",
|
||||
"top_k_top_p_filtering",
|
||||
]
|
||||
)
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel"]
|
||||
|
||||
|
@ -2278,21 +2281,25 @@ else:
|
|||
_import_structure["activations_tf"] = []
|
||||
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
|
||||
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
|
||||
_import_structure["generation_tf_logits_process"] = [
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
"TFLogitsWarper",
|
||||
"TFMinLengthLogitsProcessor",
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
]
|
||||
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
|
||||
_import_structure["generation_tf_utils"] = []
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
"TFLogitsWarper",
|
||||
"TFMinLengthLogitsProcessor",
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
"TFGenerationMixin",
|
||||
"tf_top_k_top_p_filtering",
|
||||
]
|
||||
)
|
||||
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
|
||||
_import_structure["modeling_tf_outputs"] = []
|
||||
_import_structure["modeling_tf_utils"] = [
|
||||
|
@ -2915,18 +2922,21 @@ except OptionalDependencyNotAvailable:
|
|||
name for name in dir(dummy_flax_objects) if not name.startswith("_")
|
||||
]
|
||||
else:
|
||||
_import_structure["generation_flax_logits_process"] = [
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
]
|
||||
_import_structure["generation_flax_utils"] = []
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
"FlaxGenerationMixin",
|
||||
]
|
||||
)
|
||||
_import_structure["modeling_flax_outputs"] = []
|
||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||
_import_structure["models.albert"].extend(
|
||||
|
@ -3834,38 +3844,37 @@ if TYPE_CHECKING:
|
|||
TextDataset,
|
||||
TextDatasetForNextSentencePrediction,
|
||||
)
|
||||
from .generation_beam_constraints import (
|
||||
from .generation import (
|
||||
BeamScorer,
|
||||
BeamSearchScorer,
|
||||
ConstrainedBeamSearchScorer,
|
||||
Constraint,
|
||||
ConstraintListState,
|
||||
DisjunctiveConstraint,
|
||||
PhrasalConstraint,
|
||||
)
|
||||
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .generation_logits_process import (
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
GenerationMixin,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PhrasalConstraint,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from .generation_stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from .generation_utils import top_k_top_p_filtering
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
# PyTorch model imports
|
||||
|
@ -5037,9 +5046,10 @@ if TYPE_CHECKING:
|
|||
|
||||
# Benchmarks
|
||||
from .benchmark.benchmark_tf import TensorFlowBenchmark
|
||||
from .generation_tf_logits_process import (
|
||||
from .generation import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFGenerationMixin,
|
||||
TFLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFLogitsWarper,
|
||||
|
@ -5050,8 +5060,8 @@ if TYPE_CHECKING:
|
|||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
from .generation_tf_utils import tf_top_k_top_p_filtering
|
||||
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||
from .modeling_tf_layoutlm import (
|
||||
TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
@ -5541,9 +5551,10 @@ if TYPE_CHECKING:
|
|||
# They will raise an import error if the user tries to instantiate / use them.
|
||||
from .utils.dummy_flax_objects import *
|
||||
else:
|
||||
from .generation_flax_logits_process import (
|
||||
from .generation import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxGenerationMixin,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["beam_constraints"] = [
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"DisjunctiveConstraint",
|
||||
"PhrasalConstraint",
|
||||
]
|
||||
_import_structure["beam_search"] = [
|
||||
"BeamHypotheses",
|
||||
"BeamScorer",
|
||||
"BeamSearchScorer",
|
||||
"ConstrainedBeamSearchScorer",
|
||||
]
|
||||
_import_structure["logits_process"] = [
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
"MinLengthLogitsProcessor",
|
||||
"NoBadWordsLogitsProcessor",
|
||||
"NoRepeatNGramLogitsProcessor",
|
||||
"PrefixConstrainedLogitsProcessor",
|
||||
"RepetitionPenaltyLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"LogitNormalization",
|
||||
]
|
||||
_import_structure["stopping_criteria"] = [
|
||||
"MaxNewTokensCriteria",
|
||||
"MaxLengthCriteria",
|
||||
"MaxTimeCriteria",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
"validate_stopping_criteria",
|
||||
]
|
||||
_import_structure["utils"] = [
|
||||
"GenerationMixin",
|
||||
"top_k_top_p_filtering",
|
||||
"GreedySearchEncoderDecoderOutput",
|
||||
"GreedySearchDecoderOnlyOutput",
|
||||
"SampleEncoderDecoderOutput",
|
||||
"SampleDecoderOnlyOutput",
|
||||
"BeamSearchEncoderDecoderOutput",
|
||||
"BeamSearchDecoderOnlyOutput",
|
||||
"BeamSampleEncoderDecoderOutput",
|
||||
"BeamSampleDecoderOnlyOutput",
|
||||
"ContrastiveSearchEncoderDecoderOutput",
|
||||
"ContrastiveSearchDecoderOnlyOutput",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["tf_logits_process"] = [
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
"TFLogitsWarper",
|
||||
"TFMinLengthLogitsProcessor",
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
]
|
||||
_import_structure["tf_utils"] = [
|
||||
"TFGenerationMixin",
|
||||
"tf_top_k_top_p_filtering",
|
||||
"TFGreedySearchDecoderOnlyOutput",
|
||||
"TFGreedySearchEncoderDecoderOutput",
|
||||
"TFSampleEncoderDecoderOutput",
|
||||
"TFSampleDecoderOnlyOutput",
|
||||
"TFBeamSearchEncoderDecoderOutput",
|
||||
"TFBeamSearchDecoderOnlyOutput",
|
||||
"TFBeamSampleEncoderDecoderOutput",
|
||||
"TFBeamSampleDecoderOnlyOutput",
|
||||
"TFContrastiveSearchEncoderDecoderOutput",
|
||||
"TFContrastiveSearchDecoderOnlyOutput",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["flax_logits_process"] = [
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
]
|
||||
_import_structure["flax_utils"] = [
|
||||
"FlaxGenerationMixin",
|
||||
"FlaxGreedySearchOutput",
|
||||
"FlaxSampleOutput",
|
||||
"FlaxBeamSearchOutput",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
MaxNewTokensCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
validate_stopping_criteria,
|
||||
)
|
||||
from .utils import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
ContrastiveSearchDecoderOnlyOutput,
|
||||
ContrastiveSearchEncoderDecoderOutput,
|
||||
GenerationMixin,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
SampleDecoderOnlyOutput,
|
||||
SampleEncoderDecoderOutput,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .tf_logits_process import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFLogitsWarper,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
)
|
||||
from .tf_utils import (
|
||||
TFBeamSampleDecoderOnlyOutput,
|
||||
TFBeamSampleEncoderDecoderOutput,
|
||||
TFBeamSearchDecoderOnlyOutput,
|
||||
TFBeamSearchEncoderDecoderOutput,
|
||||
TFContrastiveSearchDecoderOnlyOutput,
|
||||
TFContrastiveSearchEncoderDecoderOutput,
|
||||
TFGenerationMixin,
|
||||
TFGreedySearchDecoderOnlyOutput,
|
||||
TFGreedySearchEncoderDecoderOutput,
|
||||
TFSampleDecoderOnlyOutput,
|
||||
TFSampleEncoderDecoderOutput,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
|
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 
-from .generation_beam_constraints import Constraint, ConstraintListState
-from .utils import add_start_docstrings
+from ..utils import add_start_docstrings
+from .beam_constraints import Constraint, ConstraintListState
 
 
 PROCESS_INPUTS_DOCSTRING = r"""

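The two hunks around here show the mechanical consequence of the move: these files now live one directory deeper, so imports of top-level helpers gain an extra dot while sibling generation modules are imported with a single dot and without the `generation_` prefix. A comment-only sketch of the layout this implies (the directory prefix is inferred from the imports, not stated in the diff):

```python
# Before: transformers/generation_beam_search.py
#     from .utils import add_start_docstrings               # utils.py is a sibling module
#     from .generation_beam_constraints import Constraint    # flat, prefixed module name
#
# After: transformers/generation/beam_search.py
#     from ..utils import add_start_docstrings              # utils.py is now one level up
#     from .beam_constraints import Constraint               # sibling inside generation/
```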
@@ -19,8 +19,8 @@ import jax
 import jax.lax as lax
 import jax.numpy as jnp
 
-from .utils import add_start_docstrings
-from .utils.logging import get_logger
+from ..utils import add_start_docstrings
+from ..utils.logging import get_logger
 
 
 logger = get_logger(__name__)

@ -0,0 +1,947 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
||||
from ..models.auto import (
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..utils import ModelOutput, logging
|
||||
from .flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxGreedySearchOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using greedy search.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxSampleOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using sampling.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBeamSearchOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using beam search.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
scores (`jnp.ndarray` of shape `(batch_size,)`):
|
||||
The scores (log probabilities) of the generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
scores: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class GreedyState:
|
||||
cur_len: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
running_token: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class SampleState:
|
||||
cur_len: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
running_token: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
prng_key: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class BeamSearchState:
|
||||
cur_len: jnp.ndarray
|
||||
running_sequences: jnp.ndarray
|
||||
running_scores: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
scores: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
|
||||
|
||||
|
||||
class FlaxGenerationMixin:
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in
|
||||
[`FlaxPreTrainedModel`].
|
||||
|
||||
The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:
|
||||
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _run_loop_in_debug(cond_fn, body_fn, init_state):
|
||||
"""
|
||||
Run generation in untraced mode. This should only be used for debugging purposes.
|
||||
"""
|
||||
state = init_state
|
||||
while cond_fn(state):
|
||||
state = body_fn(state)
|
||||
return state
|
||||
|
||||
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
|
||||
encoder_kwargs = {
|
||||
argument: value
|
||||
for argument, value in model_kwargs.items()
|
||||
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
|
||||
}
|
||||
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _expand_to_num_beams(tensor, num_beams):
|
||||
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
|
||||
|
||||
def _adapt_logits_for_beam_search(self, logits):
|
||||
"""
|
||||
This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
|
||||
search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`].
|
||||
"""
|
||||
return logits
|
||||
|
||||
def _validate_model_class(self):
|
||||
"""
|
||||
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
|
||||
right class to use.
|
||||
"""
|
||||
if not hasattr(self, "prepare_inputs_for_generation"):
|
||||
generate_compatible_mappings = [
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
]
|
||||
generate_compatible_classes = set()
|
||||
for model_mapping in generate_compatible_mappings:
|
||||
supported_models = model_mapping.get(type(self.config), default=None)
|
||||
if supported_models is not None:
|
||||
generate_compatible_classes.add(supported_models.__name__)
|
||||
exception_message = (
|
||||
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
|
||||
"it doesn't have a language model head."
|
||||
)
|
||||
if generate_compatible_classes:
|
||||
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
|
||||
raise TypeError(exception_message)
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
raise ValueError(
|
||||
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
||||
" generate arguments will also show up in this list)"
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
max_length: Optional[int] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
prng_key: Optional[jnp.ndarray] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
num_beams: Optional[int] = None,
|
||||
no_repeat_ngram_size: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
length_penalty: Optional[float] = None,
|
||||
early_stopping: Optional[bool] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head. The method supports the following
|
||||
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
||||
|
||||
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
|
||||
defined in the model's config (`config.json`) which in turn defaults to the
|
||||
[`~modeling_utils.PretrainedConfig`] of the model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Most of these parameters are explained in more detail in [this blog
|
||||
post](https://huggingface.co/blog/how-to-generate).
|
||||
|
||||
Parameters:
|
||||
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
max_length (`int`, *optional*, defaults to `model.config.max_length`):
|
||||
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
||||
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
|
||||
the prompt.
|
||||
max_new_tokens (`int`, *optional*):
|
||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||
do_sample (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use sampling ; use greedy decoding otherwise.
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
The value used to module the next token probabilities.
|
||||
top_k (`int`, *optional*, defaults to 50):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`, *optional*, defaults to 1.0):
|
||||
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
|
||||
are kept for generation.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
bos_token_id (`int`, *optional*):
|
||||
The id of the *beginning-of-sequence* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
num_beams (`int`, *optional*, defaults to 1):
|
||||
Number of beams for beam search. 1 means no beam search.
|
||||
decoder_start_token_id (`int`, *optional*):
|
||||
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
|
||||
trace (`bool`, *optional*, defaults to `True`):
|
||||
Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
|
||||
considerably slower runtime.
|
||||
params (`Dict[str, jnp.ndarray]`, *optional*):
|
||||
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
||||
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
||||
should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
|
||||
|
||||
Return:
|
||||
[`~utils.ModelOutput`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||
>>> input_context = "The dog"
|
||||
>>> # encode input context
|
||||
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
|
||||
>>> # generate candidates using sampling
|
||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
```"""
|
||||
# Validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# set init values
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||
|
||||
# decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
|
||||
if not self.config.is_encoder_decoder and not trace:
|
||||
if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
||||
)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# add encoder_outputs to model_kwargs
|
||||
if model_kwargs.get("encoder_outputs") is None:
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
|
||||
# prepare decoder_input_ids for generation
|
||||
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||
|
||||
# Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
if max_length is None and max_new_tokens is None:
|
||||
warnings.warn(
|
||||
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
|
||||
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
|
||||
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
|
||||
"using `max_new_tokens` to control the maximum length of the generation.",
|
||||
UserWarning,
|
||||
)
|
||||
elif max_length is None and max_new_tokens is not None:
|
||||
max_length = max_new_tokens + input_ids_seq_length
|
||||
elif max_length is not None and max_new_tokens is not None:
|
||||
raise ValueError(
|
||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
||||
" documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
# default to config if still None
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
|
||||
if min_length is not None and min_length > max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||
f"length ({max_length})"
|
||||
)
|
||||
if input_ids_seq_length >= max_length:
|
||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
|
||||
"`max_new_tokens`."
|
||||
)
|
||||
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
|
||||
if not do_sample and num_beams == 1:
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
return self._greedy_search(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif do_sample and num_beams == 1:
|
||||
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
return self._sample(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
prng_key,
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif not do_sample and num_beams > 1:
|
||||
# broadcast input_ids & encoder_outputs
|
||||
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
|
||||
|
||||
if "encoder_outputs" in model_kwargs:
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
|
||||
)
|
||||
|
||||
if "attention_mask" in model_kwargs:
|
||||
model_kwargs["attention_mask"] = self._expand_to_num_beams(
|
||||
model_kwargs["attention_mask"], num_beams=num_beams
|
||||
)
|
||||
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
|
||||
return self._beam_search(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
length_penalty=length_penalty,
|
||||
early_stopping=early_stopping,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Beam sampling is currently not implemented.")
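# Dispatch summary (illustrative comment, not part of the original source):
#   do_sample=False, num_beams=1  -> _greedy_search
#   do_sample=True,  num_beams=1  -> _sample
#   do_sample=False, num_beams>1  -> _beam_search
#   do_sample=True,  num_beams>1  -> NotImplementedError (beam sampling not supported here)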
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
|
||||
) -> FlaxLogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
|
||||
instances used for multinomial sampling.
|
||||
"""
|
||||
|
||||
# init warp parameters
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
top_p = top_p if top_p is not None else self.config.top_p
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
# instantiate warpers list
|
||||
warpers = FlaxLogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if temperature is not None and temperature != 1.0:
|
||||
warpers.append(FlaxTemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
|
||||
|
||||
return warpers
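# Illustrative usage sketch (assumed values, not part of the original source):
#   warpers = self._get_logits_warper(top_k=50, top_p=0.9, temperature=0.7)
#   scores = warpers(input_ids, scores, cur_len)
# The warpers run in insertion order: temperature scaling, then top-k, then top-p filtering.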
|
||||
|
||||
def _get_logits_processor(
|
||||
self,
|
||||
no_repeat_ngram_size: int,
|
||||
min_length: int,
|
||||
max_length: int,
|
||||
eos_token_id: int,
|
||||
forced_bos_token_id: int,
|
||||
forced_eos_token_id: int,
|
||||
) -> FlaxLogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
|
||||
instances used to modify the scores of the language model head.
|
||||
"""
|
||||
processors = FlaxLogitsProcessorList()
|
||||
|
||||
# init warp parameters
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
)
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
forced_bos_token_id = (
|
||||
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||
)
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||
processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
|
||||
if forced_bos_token_id is not None:
|
||||
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||
if forced_eos_token_id is not None:
|
||||
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
return processors
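# Illustrative effect (not part of the original source): FlaxMinLengthLogitsProcessor pushes the
# EOS logit to a very large negative value while cur_len < min_length, FlaxForcedBOSTokenLogitsProcessor
# forces the first generated token, and FlaxForcedEOSTokenLogitsProcessor forces EOS at the last
# position before max_length.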
|
||||
|
||||
def _greedy_search(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||
|
||||
# per batch-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||
|
||||
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||
model = self.decode if self.config.is_encoder_decoder else self
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = GreedyState(
|
||||
cur_len=cur_len,
|
||||
sequences=sequences,
|
||||
running_token=input_ids,
|
||||
is_sent_finished=is_sent_finished,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
"""state termination condition fn."""
|
||||
has_reached_max_length = state.cur_len == max_length
|
||||
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||
return ~finish_generation
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
||||
logits = model_outputs.logits[:, -1]
|
||||
|
||||
# apply min_length, ...
|
||||
logits = logits_processor(state.sequences, logits, state.cur_len)
|
||||
|
||||
next_token = jnp.argmax(logits, axis=-1)
|
||||
|
||||
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||
return GreedyState(
|
||||
cur_len=state.cur_len + 1,
|
||||
sequences=next_sequences,
|
||||
running_token=next_token,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
if input_ids.shape[1] > 1:
|
||||
state = greedy_search_body_fn(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
|
||||
return FlaxGreedySearchOutput(sequences=state.sequences)
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
prng_key: Optional[jnp.ndarray] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||
|
||||
# per batch-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||
|
||||
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||
model = self.decode if self.config.is_encoder_decoder else self
|
||||
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = SampleState(
|
||||
cur_len=cur_len,
|
||||
sequences=sequences,
|
||||
running_token=input_ids,
|
||||
is_sent_finished=is_sent_finished,
|
||||
prng_key=prng_key,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def sample_search_cond_fn(state):
|
||||
"""state termination condition fn."""
|
||||
has_reached_max_length = state.cur_len == max_length
|
||||
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||
return ~finish_generation
|
||||
|
||||
def sample_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
||||
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
||||
|
||||
logits = model_outputs.logits[:, -1]
|
||||
|
||||
# apply min_length, ...
|
||||
logits = logits_processor(state.sequences, logits, state.cur_len)
|
||||
# apply top_p, top_k, temperature
|
||||
logits = logits_warper(logits, logits, state.cur_len)
|
||||
|
||||
next_token = jax.random.categorical(prng_key, logits, axis=-1)
|
||||
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||
|
||||
return SampleState(
|
||||
cur_len=state.cur_len + 1,
|
||||
sequences=next_sequences,
|
||||
running_token=next_token,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
prng_key=prng_key_next,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
if input_ids.shape[1] > 1:
|
||||
state = sample_search_body_fn(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
||||
|
||||
return FlaxSampleOutput(sequences=state.sequences)
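# Illustrative sketch of the per-step sampling primitive used above (not part of the original source):
#   key, next_key = jax.random.split(state.prng_key)
#   next_token = jax.random.categorical(key, logits, axis=-1)  # shape (batch_size,)
# Splitting the PRNG key at every step keeps sampling reproducible for a fixed initial prng_key.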
|
||||
|
||||
def _beam_search(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
length_penalty: Optional[float] = None,
|
||||
early_stopping: Optional[bool] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
):
|
||||
"""
|
||||
This beam search function is heavily inspired by Flax's official example:
|
||||
https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
|
||||
"""
|
||||
|
||||
def flatten_beam_dim(tensor):
|
||||
"""Flattens the first two dimensions of a non-scalar array."""
|
||||
# ignore scalars (e.g. cache index)
|
||||
if tensor.ndim == 0:
|
||||
return tensor
|
||||
return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
|
||||
|
||||
def unflatten_beam_dim(tensor, batch_size, num_beams):
|
||||
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
||||
# ignore scalars (e.g. cache index)
|
||||
if tensor.ndim == 0:
|
||||
return tensor
|
||||
return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
|
||||
|
||||
def gather_beams(nested, beam_indices, batch_size, new_num_beams):
|
||||
"""
|
||||
Gathers the beam slices indexed by beam_indices into new beam array.
|
||||
"""
|
||||
batch_indices = jnp.reshape(
|
||||
jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
|
||||
)
|
||||
|
||||
def gather_fn(tensor):
|
||||
# ignore scalars (e.g. cache index)
|
||||
if tensor.ndim == 0:
|
||||
return tensor
|
||||
else:
|
||||
return tensor[batch_indices, beam_indices]
|
||||
|
||||
return jax.tree_util.tree_map(gather_fn, nested)
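# Shape example for the helpers above (illustrative, not part of the original source):
#   flatten_beam_dim:   (batch=2, beams=3, len=5) -> (6, 5)
#   unflatten_beam_dim: (6, 5) with batch_size=2, num_beams=3 -> (2, 3, 5)
#   gather_beams:       with beam_indices of shape (2, 3), selects which of the existing beams
#                       each new beam continues from, independently per batch item.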
|
||||
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
|
||||
batch_size, num_beams, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch,beam-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
||||
running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
|
||||
running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
|
||||
|
||||
# per batch,beam-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
|
||||
|
||||
# per batch,beam-item score, logprobs
|
||||
running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
|
||||
scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
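# Illustrative note (not part of the original source): for num_beams=3 the running scores start as
# [[0.0, -1e7, -1e7], ...] so that on the first step only the first of the (identical) beams
# contributes candidates; `scores` starts at -1e7 everywhere because no finished hypothesis exists yet.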
|
||||
|
||||
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||
model = self.decode if self.config.is_encoder_decoder else self
|
||||
|
||||
# flatten beam dim
|
||||
if "encoder_outputs" in model_kwargs:
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"]
|
||||
)
|
||||
if "attention_mask" in model_kwargs:
|
||||
model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
|
||||
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = BeamSearchState(
|
||||
cur_len=cur_len,
|
||||
running_sequences=running_sequences,
|
||||
running_scores=running_scores,
|
||||
sequences=sequences,
|
||||
scores=scores,
|
||||
is_sent_finished=is_sent_finished,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def beam_search_cond_fn(state):
|
||||
"""beam search state termination condition fn."""
|
||||
|
||||
# 1. is less than max length?
|
||||
not_max_length_yet = state.cur_len < max_length
|
||||
|
||||
# 2. can the new beams still improve?
|
||||
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
|
||||
worst_finished_score = jnp.where(
|
||||
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
|
||||
)
|
||||
improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
|
||||
|
||||
# 3. is there still a beam that has not finished?
|
||||
still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
|
||||
|
||||
return not_max_length_yet & still_open_beam & improvement_still_possible
|
||||
|
||||
def beam_search_body_fn(state, input_ids_length=1):
|
||||
"""beam search state update fn."""
|
||||
# 1. Forward current tokens
|
||||
# Collect the current position slice along length to feed the fast
|
||||
# autoregressive decoder model. Flatten the beam dimension into batch
|
||||
# dimension for feeding into the model.
|
||||
# unflatten beam dimension
|
||||
# Unflatten beam dimension in attention cache arrays
|
||||
input_token = flatten_beam_dim(
|
||||
lax.dynamic_slice(
|
||||
state.running_sequences,
|
||||
(0, 0, state.cur_len - input_ids_length),
|
||||
(batch_size, num_beams, input_ids_length),
|
||||
)
|
||||
)
|
||||
model_outputs = model(input_token, params=params, **state.model_kwargs)
|
||||
|
||||
logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
|
||||
cache = jax.tree_util.tree_map(
|
||||
lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
|
||||
)
|
||||
|
||||
# adapt logits for FlaxMarianMTModel
|
||||
logits = self._adapt_logits_for_beam_search(logits)
|
||||
|
||||
# 2. Compute log probs
|
||||
# get log probabilities from logits,
|
||||
# process logits with processors (*e.g.* min_length, ...), and
|
||||
# add new logprobs to existing running logprobs scores.
|
||||
log_probs = jax.nn.log_softmax(logits)
|
||||
log_probs = logits_processor(
|
||||
flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
|
||||
)
|
||||
log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
|
||||
log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
|
||||
vocab_size = log_probs.shape[2]
|
||||
log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
|
||||
|
||||
# 3. Retrieve top-K
|
||||
# Each item in batch has num_beams * vocab_size candidate sequences.
|
||||
# For each item, get the top 2*k candidates with the highest log-
|
||||
# probabilities. We gather the top 2*K beams here so that even if the best
|
||||
# K sequences reach EOS simultaneously, we have another K sequences
|
||||
# remaining to continue the live beam search.
|
||||
# Gather the top 2*K scores from _all_ beams.
|
||||
# Gather 2*k top beams.
|
||||
# Recover the beam index by floor division.
|
||||
# Recover token id by modulo division and expand Id array for broadcasting.
|
||||
# Update sequences for the 2*K top-k new sequences.
|
||||
beams_to_keep = 2 * num_beams
|
||||
topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
|
||||
topk_beam_indices = topk_indices // vocab_size
|
||||
topk_running_sequences = gather_beams(
|
||||
state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
|
||||
)
|
||||
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
|
||||
topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
|
||||
|
||||
# 4. Check which sequences have ended
|
||||
# Update current sequences:
|
||||
# Did any of these sequences reach an end marker?
|
||||
# To prevent these just finished sequences from being added to the current sequences
|
||||
# set of active beam search sequences, set their log probs to a very large
|
||||
# negative value.
|
||||
did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
|
||||
running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
|
||||
# 5. Get running sequences scores for next
|
||||
# Determine the top k beam indices (from top 2*k beams) from log probs
|
||||
# and gather top k beams (from top 2*k beams).
|
||||
next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
|
||||
next_running_sequences, next_running_scores = gather_beams(
|
||||
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
|
||||
)
|
||||
|
||||
# 6. Process topk logits
|
||||
# Further process log probs:
|
||||
# - add length penalty
|
||||
# - make sure no scores can be added anymore if beam is full
|
||||
# - make sure still running sequences cannot be chosen as finalized beam
|
||||
topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
|
||||
beams_in_batch_are_full = (
|
||||
jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
|
||||
& early_stopping
|
||||
)
|
||||
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
|
||||
topk_log_probs += add_penalty * np.array(-1.0e7)
|
||||
|
||||
# 7. Get scores, sequences, is sentence finished for next.
|
||||
# Combine sequences, scores, and flags along the beam dimension and compare
|
||||
# new finished sequence scores to existing finished scores and select the
|
||||
# best from the new set of beams
|
||||
merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
|
||||
merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
|
||||
merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
|
||||
topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
|
||||
next_sequences, next_scores, next_is_sent_finished = gather_beams(
|
||||
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
|
||||
)
|
||||
|
||||
# 8. Update model kwargs.
|
||||
# Determine the top k beam indices from the original set of all beams.
|
||||
# With these, gather the top k beam-associated caches.
|
||||
next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
|
||||
next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
|
||||
model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
|
||||
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||
|
||||
return BeamSearchState(
|
||||
cur_len=state.cur_len + 1,
|
||||
running_scores=next_running_scores,
|
||||
running_sequences=next_running_sequences,
|
||||
scores=next_scores,
|
||||
sequences=next_sequences,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
if input_ids.shape[-1] > 1:
|
||||
state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
|
||||
|
||||
# Account for the edge-case where there are no finished sequences for a
|
||||
# particular batch item. If so, return running sequences for that batch item.
|
||||
none_finished = jnp.any(state.is_sent_finished, axis=1)
|
||||
sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
|
||||
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
|
||||
|
||||
# take best beam for each batch
|
||||
sequences = sequences[:, -1]
|
||||
scores = scores[:, -1]
|
||||
|
||||
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
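# Illustrative note (not part of the original source): `lax.top_k` returns candidates in descending
# score order and the indices are flipped above, so beams end up sorted ascending and index -1 is the
# best beam per batch item; the returned sequences therefore have shape (batch_size, max_length).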
|
|
@@ -20,8 +20,8 @@ from typing import Callable, Iterable, List, Optional, Tuple
import numpy as np
|
||||
import torch
|
||||
|
||||
from .utils import add_start_docstrings
|
||||
from .utils.logging import get_logger
|
||||
from ..utils import add_start_docstrings
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
|
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from .utils import add_start_docstrings
|
||||
from ..utils import add_start_docstrings
|
||||
|
||||
|
||||
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
|
|
@@ -19,9 +19,9 @@ from typing import List, Tuple
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from .tf_utils import stable_softmax
|
||||
from .utils import add_start_docstrings
|
||||
from .utils.logging import get_logger
|
||||
from ..tf_utils import stable_softmax
|
||||
from ..utils import add_start_docstrings
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -14,934 +14,15 @@
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from .generation import FlaxGenerationMixin
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
|
||||
from .generation_flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
from .models.auto import (
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from .utils import ModelOutput, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxGreedySearchOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using greedy search.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxSampleOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using sampling.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBeamSearchOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using beam search.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
|
||||
The generated sequences.
|
||||
scores (`jnp.ndarray` of shape `(batch_size,)`):
|
||||
The scores (log probabilities) of the generated sequences.
|
||||
"""
|
||||
|
||||
sequences: jnp.ndarray = None
|
||||
scores: jnp.ndarray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class GreedyState:
|
||||
cur_len: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
running_token: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class SampleState:
|
||||
cur_len: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
running_token: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
prng_key: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class BeamSearchState:
|
||||
cur_len: jnp.ndarray
|
||||
running_sequences: jnp.ndarray
|
||||
running_scores: jnp.ndarray
|
||||
sequences: jnp.ndarray
|
||||
scores: jnp.ndarray
|
||||
is_sent_finished: jnp.ndarray
|
||||
model_kwargs: Dict[str, jnp.ndarray]
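# Illustrative note (not part of the original source): decorating these state containers with
# @flax.struct.dataclass registers them as JAX pytrees, which is what allows the decoding loops
# below to carry them through `lax.while_loop` unchanged.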
|
||||
|
||||
|
||||
class FlaxGenerationMixin:
|
||||
"""
|
||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in
|
||||
[`FlaxPreTrainedModel`].
|
||||
|
||||
The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
|
||||
- *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
|
||||
`num_beams=1` and `do_sample=False`.
|
||||
- *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
|
||||
and `do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
|
||||
and `do_sample=False`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _run_loop_in_debug(cond_fn, body_fn, init_state):
|
||||
"""
|
||||
Run generation in untraced mode. This should only be used for debugging purposes.
|
||||
"""
|
||||
state = init_state
|
||||
while cond_fn(state):
|
||||
state = body_fn(state)
|
||||
return state
|
||||
|
||||
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
|
||||
encoder_kwargs = {
|
||||
argument: value
|
||||
for argument, value in model_kwargs.items()
|
||||
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
|
||||
}
|
||||
model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _expand_to_num_beams(tensor, num_beams):
|
||||
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
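# Shape example (illustrative, not part of the original source): a (batch_size, seq_len) tensor
# becomes (batch_size, num_beams, seq_len), e.g. (2, 7) with num_beams=4 -> (2, 4, 7).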
|
||||
|
||||
def _adapt_logits_for_beam_search(self, logits):
|
||||
"""
|
||||
This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
|
||||
search behavior. Note that the only model that overwrites this method is [`~transformers.FlaxMarianMTModel`].
|
||||
"""
|
||||
return logits
|
||||
|
||||
def _validate_model_class(self):
|
||||
"""
|
||||
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
|
||||
right class to use.
|
||||
"""
|
||||
if not hasattr(self, "prepare_inputs_for_generation"):
|
||||
generate_compatible_mappings = [
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
]
|
||||
generate_compatible_classes = set()
|
||||
for model_mapping in generate_compatible_mappings:
|
||||
supported_models = model_mapping.get(type(self.config), default=None)
|
||||
if supported_models is not None:
|
||||
generate_compatible_classes.add(supported_models.__name__)
|
||||
exception_message = (
|
||||
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
|
||||
"it doesn't have a language model head."
|
||||
)
|
||||
if generate_compatible_classes:
|
||||
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
|
||||
raise TypeError(exception_message)
|
||||
|
||||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||
unused_model_args = []
|
||||
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||
# `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
|
||||
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||
if "kwargs" in model_args:
|
||||
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||
for key, value in model_kwargs.items():
|
||||
if value is not None and key not in model_args:
|
||||
unused_model_args.append(key)
|
||||
|
||||
if unused_model_args:
|
||||
raise ValueError(
|
||||
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
||||
" generate arguments will also show up in this list)"
|
||||
)
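# Illustrative example (assumed call, not part of the original source): a typo'd argument such as
# `model.generate(input_ids, do_samples=True)` ends up in `model_kwargs`, is not accepted by
# `prepare_inputs_for_generation` (or `__call__`), and triggers the ValueError above listing it.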
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: jnp.ndarray,
|
||||
max_length: Optional[int] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
prng_key: Optional[jnp.ndarray] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
num_beams: Optional[int] = None,
|
||||
no_repeat_ngram_size: Optional[int] = None,
|
||||
min_length: Optional[int] = None,
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
length_penalty: Optional[float] = None,
|
||||
early_stopping: Optional[bool] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head. The method supports the following
|
||||
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
||||
|
||||
- *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
|
||||
`num_beams=1` and `do_sample=False`.
|
||||
- *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
|
||||
and `do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if `num_beams>1`
|
||||
and `do_sample=False`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
|
||||
defined in the model's config (`config.json`) which in turn defaults to the
|
||||
[`~configuration_utils.PretrainedConfig`] of the model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Most of these parameters are explained in more detail in [this blog
|
||||
post](https://huggingface.co/blog/how-to-generate).
|
||||
|
||||
Parameters:
|
||||
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
max_length (`int`, *optional*, defaults to `model.config.max_length`):
|
||||
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
|
||||
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
|
||||
the prompt.
|
||||
max_new_tokens (`int`, *optional*):
|
||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||
do_sample (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use sampling; use greedy decoding otherwise.
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
The value used to module the next token probabilities.
|
||||
top_k (`int`, *optional*, defaults to 50):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (`float`, *optional*, defaults to 1.0):
|
||||
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
|
||||
are kept for generation.
|
||||
pad_token_id (`int`, *optional*):
|
||||
The id of the *padding* token.
|
||||
bos_token_id (`int`, *optional*):
|
||||
The id of the *beginning-of-sequence* token.
|
||||
eos_token_id (`int`, *optional*):
|
||||
The id of the *end-of-sequence* token.
|
||||
num_beams (`int`, *optional*, defaults to 1):
|
||||
Number of beams for beam search. 1 means no beam search.
|
||||
decoder_start_token_id (`int`, *optional*):
|
||||
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
|
||||
trace (`bool`, *optional*, defaults to `True`):
|
||||
Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
|
||||
considerably slower runtime.
|
||||
params (`Dict[str, jnp.ndarray]`, *optional*):
|
||||
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
|
||||
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
|
||||
should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
|
||||
|
||||
Return:
|
||||
[`~utils.ModelOutput`].
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||
>>> input_context = "The dog"
|
||||
>>> # encode input context
|
||||
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
|
||||
>>> # generate candidates using sampling
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```"""
|
||||
# Validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
|
||||
# set init values
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||
|
||||
# decoder-only models should use left-padding for generation (can't be checked with `trace=True`)
|
||||
if not self.config.is_encoder_decoder and not trace:
|
||||
if pad_token_id is not None and jnp.sum(input_ids[:, -1] == pad_token_id) > 0:
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
||||
)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
# add encoder_outputs to model_kwargs
|
||||
if model_kwargs.get("encoder_outputs") is None:
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
|
||||
# prepare decoder_input_ids for generation
|
||||
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||
|
||||
# Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
if max_length is None and max_new_tokens is None:
|
||||
warnings.warn(
|
||||
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
|
||||
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
|
||||
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
|
||||
"using `max_new_tokens` to control the maximum length of the generation.",
|
||||
UserWarning,
|
||||
)
|
||||
elif max_length is None and max_new_tokens is not None:
|
||||
max_length = max_new_tokens + input_ids_seq_length
|
||||
elif max_length is not None and max_new_tokens is not None:
|
||||
raise ValueError(
|
||||
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
|
||||
" limit to the generated output length. Remove one of those arguments. Please refer to the"
|
||||
" documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
# default to config if still None
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
min_length = min_length if min_length is not None else self.config.min_length
|
||||
|
||||
if min_length is not None and min_length > max_length:
|
||||
raise ValueError(
|
||||
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
|
||||
f"length ({max_length})"
|
||||
)
|
||||
if input_ids_seq_length >= max_length:
|
||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||
logger.warning(
|
||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
|
||||
"`max_new_tokens`."
|
||||
)
|
||||
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
|
||||
if not do_sample and num_beams == 1:
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
return self._greedy_search(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif do_sample and num_beams == 1:
|
||||
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
return self._sample(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
prng_key,
|
||||
logits_warper=logits_warper,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
elif not do_sample and num_beams > 1:
|
||||
# broadcast input_ids & encoder_outputs
|
||||
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
|
||||
|
||||
if "encoder_outputs" in model_kwargs:
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
|
||||
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
|
||||
)
|
||||
|
||||
if "attention_mask" in model_kwargs:
|
||||
model_kwargs["attention_mask"] = self._expand_to_num_beams(
|
||||
model_kwargs["attention_mask"], num_beams=num_beams
|
||||
)
|
||||
|
||||
logits_processor = self._get_logits_processor(
|
||||
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||
)
|
||||
|
||||
return self._beam_search(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
length_penalty=length_penalty,
|
||||
early_stopping=early_stopping,
|
||||
logits_processor=logits_processor,
|
||||
trace=trace,
|
||||
params=params,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Beam sampling is currently not implemented.")
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
|
||||
) -> FlaxLogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsWarper`]
|
||||
instances used for multinomial sampling.
|
||||
"""
|
||||
|
||||
# init warp parameters
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
top_p = top_p if top_p is not None else self.config.top_p
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
# instantiate warpers list
|
||||
warpers = FlaxLogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if temperature is not None and temperature != 1.0:
|
||||
warpers.append(FlaxTemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
|
||||
|
||||
return warpers
|
||||
|
||||
def _get_logits_processor(
|
||||
self,
|
||||
no_repeat_ngram_size: int,
|
||||
min_length: int,
|
||||
max_length: int,
|
||||
eos_token_id: int,
|
||||
forced_bos_token_id: int,
|
||||
forced_eos_token_id: int,
|
||||
) -> FlaxLogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`]
|
||||
instances used to modify the scores of the language model head.
|
||||
"""
|
||||
processors = FlaxLogitsProcessorList()
|
||||
|
||||
# init warp parameters
|
||||
no_repeat_ngram_size = (
|
||||
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||
)
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
forced_bos_token_id = (
|
||||
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||
)
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||
processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
|
||||
if forced_bos_token_id is not None:
|
||||
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||
if forced_eos_token_id is not None:
|
||||
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
return processors
|
||||
|
||||
def _greedy_search(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||
|
||||
# per batch-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||
|
||||
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
||||
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
||||
model = self.decode if self.config.is_encoder_decoder else self
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = GreedyState(
|
||||
cur_len=cur_len,
|
||||
sequences=sequences,
|
||||
running_token=input_ids,
|
||||
is_sent_finished=is_sent_finished,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
"""state termination condition fn."""
|
||||
has_reached_max_length = state.cur_len == max_length
|
||||
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||
return ~finish_generation
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
||||
logits = model_outputs.logits[:, -1]
|
||||
|
||||
# apply min_length, ...
|
||||
logits = logits_processor(state.sequences, logits, state.cur_len)
|
||||
|
||||
next_token = jnp.argmax(logits, axis=-1)
|
||||
|
||||
next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
||||
return GreedyState(
|
||||
cur_len=state.cur_len + 1,
|
||||
sequences=next_sequences,
|
||||
running_token=next_token,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
if input_ids.shape[1] > 1:
|
||||
state = greedy_search_body_fn(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
|
||||
return FlaxGreedySearchOutput(sequences=state.sequences)
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
prng_key: Optional[jnp.ndarray] = None,
|
||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            running_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.running_token, params=params, **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply min_length, ...
            logits = logits_processor(state.sequences, logits, state.cur_len)
            # apply top_p, top_k, temperature
            logits = logits_warper(logits, logits, state.cur_len)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                running_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)

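(Editor's note: the snippet below is an illustrative sketch, not part of the commit. It shows the `cond_fn`/`body_fn` contract that `sample_search_cond_fn` and `sample_search_body_fn` above rely on: `lax.while_loop` keeps applying the body to the carried state for as long as the condition returns True.)

```py
import jax.numpy as jnp
from jax import lax

def cond_fn(state):
    step, _ = state
    return step < 5  # analogous to "cur_len < max_length and not all sequences finished"

def body_fn(state):
    step, total = state
    return step + 1, total + step  # advance the counter and update the carried value

final_step, final_total = lax.while_loop(cond_fn, body_fn, (jnp.array(0), jnp.array(0)))
print(final_step, final_total)  # 5 10
```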
    def _beam_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_util.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            cache = jax.tree_util.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
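(Editor's note: an illustrative sketch, not part of the commit, of what `flatten_beam_dim` and `unflatten_beam_dim` above do to the `(batch, beam, ...)` arrays: the beam axis is folded into the batch axis before the decoder call and restored afterwards.)

```py
import jax.numpy as jnp

batch_size, num_beams, seq_len = 2, 3, 4
beamed = jnp.arange(batch_size * num_beams * seq_len).reshape(batch_size, num_beams, seq_len)

flat = beamed.reshape((batch_size * num_beams,) + beamed.shape[2:])  # (6, 4), what the model sees
restored = flat.reshape((batch_size, num_beams) + flat.shape[1:])    # back to (2, 3, 4)

assert restored.shape == beamed.shape
assert bool((restored == beamed).all())
```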

class FlaxGenerationMixin(FlaxGenerationMixin):
    # warning at import time
    warnings.warn(
        "Importing `FlaxGenerationMixin` from `src/transformers/generation_flax_utils.py` is deprecated and will "
        "be removed in Transformers v5. Import as `from transformers import FlaxGenerationMixin` instead.",
        FutureWarning,
    )
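(Editor's note: an illustrative sketch, not part of the commit. The shim above keeps the old module importable but warns; its deprecation message names the top-level import as the supported path.)

```py
# deprecated, emits a FutureWarning at import time
from transformers.generation_flax_utils import FlaxGenerationMixin

# supported going forward
from transformers import FlaxGenerationMixin
```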
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -33,7 +33,7 @@ from jax.random import PRNGKey
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation_flax_utils import FlaxGenerationMixin
+from .generation import FlaxGenerationMixin
 from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
 from .utils import (
     FLAX_WEIGHTS_INDEX_NAME,

@@ -43,7 +43,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation_tf_utils import TFGenerationMixin
+from .generation import TFGenerationMixin
 from .tf_utils import shape_list
 from .utils import (
     DUMMY_INPUTS,

@@ -38,7 +38,7 @@ from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .dynamic_module_utils import custom_object_save
-from .generation_utils import GenerationMixin
+from .generation import GenerationMixin
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
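(Editor's note: an illustrative sketch, not part of the commit, of the equivalent change for code outside the library: generation utilities that used to be imported from the flat `generation_*` modules now come from the `transformers.generation` package, just as the three hunks above do internally.)

```py
# before this commit
# from transformers.generation_utils import GenerationMixin
# from transformers.generation_logits_process import LogitsProcessorList

# after this commit
from transformers.generation import GenerationMixin, LogitsProcessorList
```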
@@ -21,9 +21,7 @@ import torch
 from torch import nn

 from ...configuration_utils import PretrainedConfig
-from ...generation_beam_search import BeamSearchScorer
-from ...generation_logits_process import LogitsProcessorList
-from ...generation_stopping_criteria import StoppingCriteriaList
+from ...generation import BeamSearchScorer, LogitsProcessorList, StoppingCriteriaList
 from ...modeling_outputs import ModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings

@@ -925,8 +923,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         **model_kwargs
     ) -> torch.LongTensor:
         """
-        Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
-        documentation for more information on how to set other generate input parameters.
+        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
+        for more information on how to set other generate input parameters.

         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -960,14 +958,14 @@ class RagSequenceForGeneration(RagPreTrainedModel):
                 to be set to `False` if used while training with distributed backend.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
-                function, where we set `num_return_sequences` to `num_beams`.
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
+                where we set `num_return_sequences` to `num_beams`.
             num_beams (`int`, *optional*, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
-                Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`].
+                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].

         Return:
             `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated

@@ -1486,8 +1484,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 enabled.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
-                where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
+                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
                 encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.

@@ -1073,8 +1073,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 Number of beams for beam search. 1 means no beam search.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`] function,
-                where we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
+                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
                 encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.

@@ -1676,8 +1676,8 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
         **model_kwargs
     ):
         """
-        Implements RAG sequence "thorough" decoding. Read the [`~generation_utils.GenerationMixin.generate`]`
-        documentation for more information on how to set other generate input parameters
+        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
+        for more information on how to set other generate input parameters

         Args:
             input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -1705,14 +1705,14 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
                 to be set to `False` if used while training with distributed backend.
             num_return_sequences(`int`, *optional*, defaults to 1):
                 The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation_utils.GenerationMixin.generate`]`
-                function, where we set `num_return_sequences` to `num_beams`.
+                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
+                where we set `num_return_sequences` to `num_beams`.
             num_beams (`int`, *optional*, defaults to 1):
                 Number of beams for beam search. 1 means no beam search.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
             kwargs:
-                Additional kwargs will be passed to [`~generation_utils.GenerationMixin.generate`]
+                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`]

         Return:
             `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
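(Editor's note: an illustrative sketch, not part of the commit, of the `num_beams` / `num_return_sequences` distinction the RAG docstrings above describe. A plain GPT-2 checkpoint is used so the snippet stays self-contained; the checkpoint and prompt are the editor's choice, not taken from the diff.)

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The generation code was moved", return_tensors="pt")

# beam search over 4 beams, returning 2 of them per input item
outputs = model.generate(**inputs, num_beams=4, num_return_sequences=2, max_length=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```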
@@ -94,7 +94,7 @@ class ImageToTextPipeline(Pipeline):
     def _forward(self, model_inputs, generate_kwargs=None):
         if generate_kwargs is None:
             generate_kwargs = {}
-        # FIXME: We need to pop here due to a difference in how `generation_utils.py` and `generation_tf_utils.py`
+        # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
         # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
         # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
         # in the `_prepare_model_inputs` method.

@@ -34,7 +34,7 @@ class Text2TextGenerationPipeline(Pipeline):
     up-to-date list of available models on
     [huggingface.co/models](https://huggingface.co/models?filter=text2text-generation). For a list of available
     parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)

     Usage:

@@ -206,7 +206,7 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
     currently, '*bart-large-cnn*', '*t5-small*', '*t5-base*', '*t5-large*', '*t5-3b*', '*t5-11b*'. See the up-to-date
     list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list
     of available parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)

     Usage:

@@ -274,7 +274,7 @@ class TranslationPipeline(Text2TextGenerationPipeline):
     The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
     up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation).
     For a list of available parameters, see the [following
-    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate)
+    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)

     Usage:
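(Editor's note: an illustrative sketch, not part of the commit. Only the documentation links in the pipeline docstrings above changed; the user-facing API is the same. The pipeline downloads its default checkpoint on first use.)

```py
from transformers import pipeline

summarizer = pipeline("summarization")
print(summarizer("The generation utilities were moved into a dedicated generation subpackage.", min_length=5, max_length=20))
```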
@@ -17,6 +17,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["flax"])


+class FlaxGenerationMixin(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxLogitsProcessor(metaclass=DummyObject):
     _backends = ["flax"]


@@ -80,34 +80,6 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
         requires_backends(self, ["torch"])


-class Constraint(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class ConstraintListState(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class DisjunctiveConstraint(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class PhrasalConstraint(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
 class BeamScorer(metaclass=DummyObject):
     _backends = ["torch"]


@@ -129,6 +101,27 @@ class ConstrainedBeamSearchScorer(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class Constraint(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class ConstraintListState(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class DisjunctiveConstraint(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]


@@ -143,6 +136,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class GenerationMixin(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class HammingDiversityLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]


@@ -178,6 +178,20 @@ class LogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class MaxLengthCriteria(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class MaxTimeCriteria(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class MinLengthLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]


@@ -199,6 +213,13 @@ class NoRepeatNGramLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class PhrasalConstraint(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class PrefixConstrainedLogitsProcessor(metaclass=DummyObject):
     _backends = ["torch"]


@@ -213,6 +234,20 @@ class RepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class StoppingCriteria(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class StoppingCriteriaList(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class TemperatureLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]


@@ -241,34 +276,6 @@ class TypicalLogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["torch"])


-class MaxLengthCriteria(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class MaxTimeCriteria(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class StoppingCriteria(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
-class StoppingCriteriaList(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-
 def top_k_top_p_filtering(*args, **kwargs):
     requires_backends(top_k_top_p_filtering, ["torch"])


@@ -31,6 +31,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["tf"])


+class TFGenerationMixin(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 class TFLogitsProcessor(metaclass=DummyObject):
     _backends = ["tf"]

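(Editor's note: an illustrative, simplified stand-in for the `DummyObject` / `requires_backends` pattern used in the dummy files above, not the real implementation: every public generation object gets a placeholder that fails loudly when the required backend is missing.)

```py
import importlib.util

def requires_backends(obj, backends):
    """Simplified version: raise if any of the required packages cannot be found."""
    name = getattr(obj, "__name__", obj.__class__.__name__)
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends, which are not installed: {missing}")

class DummyObject(type):
    """Metaclass for placeholder classes: attribute access re-triggers the backend check."""
    def __getattr__(cls, key):
        if key.startswith("_"):
            raise AttributeError(key)
        requires_backends(cls, cls._backends)

class GenerationMixin(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```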
@@ -490,7 +490,7 @@ from transformers.utils import cached_property
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

 from ...test_configuration_common import ConfigTester
-from ...generation.test_generation_utils import GenerationTesterMixin
+from ...generation.test_utils import GenerationTesterMixin
 from ...test_modeling_common import ModelTesterMixin, ids_tensor


@@ -23,7 +23,7 @@ from transformers.testing_utils import require_torch
 if is_torch_available():
     import torch

-    from transformers.generation_beam_constraints import DisjunctiveConstraint
+    from transformers.generation import DisjunctiveConstraint


 @require_torch

@@ -25,8 +25,13 @@ from ..test_modeling_common import floats_tensor, ids_tensor
 if is_torch_available():
     import torch

-    from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
-    from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
+    from transformers.generation import (
+        BeamHypotheses,
+        BeamSearchScorer,
+        ConstrainedBeamSearchScorer,
+        DisjunctiveConstraint,
+        PhrasalConstraint,
+    )


 class BeamSearchTester:

@@ -27,7 +27,7 @@ from ..test_modeling_flax_common import ids_tensor
 if is_flax_available():
     import jax
     import jax.numpy as jnp
-    from transformers.generation_flax_logits_process import (
+    from transformers.generation import (
         FlaxForcedBOSTokenLogitsProcessor,
         FlaxForcedEOSTokenLogitsProcessor,
         FlaxLogitsProcessorList,

@@ -26,7 +26,7 @@ if is_torch_available():
     import torch
     from torch import nn

-    from transformers.generation_logits_process import (
+    from transformers.generation import (
         EncoderNoRepeatNGramLogitsProcessor,
         ExponentialDecayLengthPenalty,
         ForcedBOSTokenLogitsProcessor,

@@ -25,7 +25,7 @@ from ..test_modeling_common import ids_tensor
 if is_torch_available():
     import torch

-    from transformers.generation_stopping_criteria import (
+    from transformers.generation import (
         MaxLengthCriteria,
         MaxNewTokensCriteria,
         MaxTimeCriteria,

@@ -26,7 +26,7 @@ from transformers.testing_utils import require_tf
 if is_tf_available():
     import tensorflow as tf

-    from transformers.generation_tf_logits_process import (
+    from transformers.generation import (
         TFForcedBOSTokenLogitsProcessor,
         TFForcedEOSTokenLogitsProcessor,
         TFForceTokensLogitsProcessor,
@@ -42,32 +42,34 @@ if is_torch_available():
         pipeline,
         top_k_top_p_filtering,
     )
-    from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
-    from transformers.generation_beam_search import BeamSearchScorer, ConstrainedBeamSearchScorer
-    from transformers.generation_logits_process import (
-        ForcedBOSTokenLogitsProcessor,
-        ForcedEOSTokenLogitsProcessor,
-        HammingDiversityLogitsProcessor,
-        InfNanRemoveLogitsProcessor,
-        LogitsProcessorList,
-        MinLengthLogitsProcessor,
-        NoBadWordsLogitsProcessor,
-        NoRepeatNGramLogitsProcessor,
-        RepetitionPenaltyLogitsProcessor,
-        TemperatureLogitsWarper,
-        TopKLogitsWarper,
-        TopPLogitsWarper,
-    )
-    from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteria, StoppingCriteriaList
-    from transformers.generation_utils import (
+    from transformers.generation import (
         BeamSampleDecoderOnlyOutput,
         BeamSampleEncoderDecoderOutput,
         BeamSearchDecoderOnlyOutput,
         BeamSearchEncoderDecoderOutput,
+        BeamSearchScorer,
+        ConstrainedBeamSearchScorer,
+        DisjunctiveConstraint,
+        ForcedBOSTokenLogitsProcessor,
+        ForcedEOSTokenLogitsProcessor,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
+        HammingDiversityLogitsProcessor,
+        InfNanRemoveLogitsProcessor,
+        LogitsProcessorList,
+        MaxLengthCriteria,
+        MinLengthLogitsProcessor,
+        NoBadWordsLogitsProcessor,
+        NoRepeatNGramLogitsProcessor,
+        PhrasalConstraint,
+        RepetitionPenaltyLogitsProcessor,
         SampleDecoderOnlyOutput,
         SampleEncoderDecoderOutput,
+        StoppingCriteria,
+        StoppingCriteriaList,
+        TemperatureLogitsWarper,
+        TopKLogitsWarper,
+        TopPLogitsWarper,
     )

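(Editor's note: an illustrative sketch, not part of the commit, using the consolidated import location the test file above now exercises. The processor class and its argument are standard for this version of the library; the checkpoint and prompt are the editor's choice.)

```py
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import LogitsProcessorList, NoRepeatNGramLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Moving the generation utilities into a package", return_tensors="pt")

# forbid repeating any 3-gram during decoding
processors = LogitsProcessorList([NoRepeatNGramLogitsProcessor(3)])
outputs = model.generate(**inputs, logits_processor=processors, max_length=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```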
@ -25,7 +25,7 @@ from transformers import BartConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import timeout_decorator # noqa
|
|||
from transformers import BartConfig, BartTokenizer, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from transformers import BertConfig, is_torch_available
|
|||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
from transformers import BertGenerationConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import unittest
|
|||
from transformers import BigBirdPegasusConfig, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import BlenderbotConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import timeout_decorator # noqa
|
|||
from transformers import BlenderbotConfig, is_flax_available
|
||||
from transformers.testing_utils import jax_device, require_flax, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import BlenderbotSmallConfig, is_torch_available
|
|||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import timeout_decorator # noqa
|
|||
from transformers import BlenderbotSmallConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import unittest
|
|||
from transformers import BloomConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import unittest
|
|||
from transformers import CodeGenConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import ConditionalDetrConfig, is_timm_available, is_vision_ava
|
|||
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
from transformers import CTRLConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from tests.test_modeling_common import floats_tensor, ids_tensor, random_attenti
|
|||
from transformers import Data2VecTextConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import unittest
|
|||
from transformers import DecisionTransformerConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from transformers import DeformableDetrConfig, is_timm_available, is_vision_avai
|
|||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_timm, require_torch_gpu, require_vision, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import DetrConfig, is_timm_available, is_vision_available
|
|||
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
|
||||
|
|
|
@ -271,7 +271,7 @@ class FlaxEncoderDecoderMixin:
|
|||
eos_token_id = enc_dec_model.config.decoder.eos_token_id
|
||||
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
|
||||
|
||||
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
|
||||
# Copied from generation.utils (GPT2 doesn't have `pad_token_id`)
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
pad_token_id = eos_token_id
|
||||
if decoder_start_token_id is None:
|
||||
|
|
|
@ -20,7 +20,7 @@ from transformers import ErnieConfig, is_torch_available
|
|||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import FSMTConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import transformers
|
|||
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import unittest
|
|||
from transformers import GPT2Config, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import transformers
|
|||
from transformers import GPT2Tokenizer, GPTNeoConfig, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import GPTNeoConfig, is_torch_available
|
|||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ import transformers
|
|||
from transformers import GPT2Tokenizer, GPTJConfig, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, tooslow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import unittest
|
|||
from transformers import GPTJConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, tooslow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from transformers import ImageGPTConfig
|
|||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
|
|
|
@ -24,7 +24,7 @@ from transformers.models.auto import get_values
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
from transformers import LiltConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from transformers.testing_utils import (
|
|||
slow,
|
||||
)
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers.models.auto import get_values
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import M2M100Config, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import MarianConfig, is_flax_available
|
|||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from transformers import MarianConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import MBartConfig, is_flax_available
|
|||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import MBartConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from transformers import MvpConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from transformers import NezhaConfig, is_torch_available
|
|||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ import timeout_decorator # noqa
|
|||
from transformers import OPTConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, require_sentencepiece, slow
|
||||
|
||||
from ...generation.test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import timeout_decorator # noqa
|
|||
from transformers import OPTConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from transformers import PegasusConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ..mbart.test_modeling_mbart import AbstractSeq2SeqIntegrationTest
|
||||
|
|
|
@ -24,7 +24,7 @@ from transformers import is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from transformers import PLBartConfig, is_torch_available
|
|||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import unittest
|
|||
from transformers import ProphetNetConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from transformers.testing_utils import (
|
|||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from copy import deepcopy
|
|||
from transformers import RobertaConfig, is_torch_available
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
|
|
@ -307,7 +307,7 @@ class FlaxEncoderDecoderMixin:
|
|||
eos_token_id = enc_dec_model.config.decoder.eos_token_id
|
||||
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
|
||||
|
||||
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
|
||||
# Copied from generation.utils (GPT2 doesn't have `pad_token_id`)
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
pad_token_id = eos_token_id
|
||||
if decoder_start_token_id is None:
|
||||
|
|
|
@ -32,7 +32,7 @@ from transformers.testing_utils import (
|
|||
)
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...generation.test_generation_utils import GenerationTesterMixin
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff