Add output in a dictionary for TF `generate` method (#12139)

* Add output args to greedy search

* Fix critical typo + make style quality

* Handle generate_beam_search

* Add dict_specific tests and fix the placement of encoder outputs

* Add  specific outputs

* Update doc

* Fix typo

* Adjust handling encoder_outputs + Fix generating for T5

* Fix generate for RAG

* Fix handling ouptut_attentions when target_mapping is not None

Take care of situations when target_mapping is provided
as there are 2-tuple of attentions

Change from:
if inputs["output_attentions"]:
    attentions = tuple(tf.transpose(t, perm(2, 3, 0, 1)) for t in attentions)

to:
if inputs["output_attentions"]:
    if inputs["target_mapping"] is not None:
        # when target_mapping is provided, there are 2-tuple of attentions
         attentions = tuple(
             tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
        )
    else:
        attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

* Rename kwargs to model_kwargs

* make style quality

* Move imports in test_modeling_tf_common.py

Move ModelOutput-related imports in test_modeling_tf_common.py
into the `is_tf_available():` statement.

* Rewrite nested if-statements

* Fix added tests
This commit is contained in:
Daniel Stancl 2021-06-23 11:52:11 +02:00 committed by GitHub
parent d4be498441
commit 26a2e36595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 672 additions and 21 deletions

View File

@ -14,15 +14,323 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from .file_utils import ModelOutput
from .utils import logging
logger = logging.get_logger(__name__)
@dataclass
class TFGreedySearchDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using greedy search.
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with
each tensor of shape :obj:`(batch_size, config.vocab_size)`).
attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFGreedySearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size, config.vocab_size)`).
encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFSampleDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using sampling.
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with
each tensor of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`).
attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(num_return_sequences*batch_size, num_heads, generated_length,
sequence_length)`.
hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(num_return_sequences*batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFSampleEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size*num_return_sequences, config.vocab_size)`).
encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape
:obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
sequence_length)`.
cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSearchDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using beam search.
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size*num_beams, config.vocab_size)`).
attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
sequence_length)`.
cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSampleDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using beam sample.
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequence)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFBeamSampleEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
Args:
sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
shorter if all batches finished early due to the :obj:`eos_token_id`.
sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequence)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Final beam scores of the generated ``sequences``.
scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
:obj:`(batch_size*num_beams, config.vocab_size)`).
encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size*num_beams, sequence_length, hidden_size)`.
decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput]
TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput]
TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput]
TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput]
class TFGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
@ -67,9 +375,14 @@ class TFGenerationMixin:
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
@ -139,6 +452,16 @@ class TFGenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more details.
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more details.
output_scores (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
forced_bos_token_id (:obj:`int`, `optional`):
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token
@ -149,10 +472,25 @@ class TFGenerationMixin:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
:class:`~transformers.file_utils.ModelOutput` or :obj:`tf.Tensor`: A
:class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when
``config.return_dict_in_generate=True``) or a :obj:`tf.Tensor`.
:obj:`tf.Tensor` of :obj:`dtype=tf.int32` and shape :obj:`(batch_size * num_return_sequences,
sequence_length)`: The generated sequences. The second dimension (sequence_length) is either equal to
:obj:`max_length` or shorter if all batches finished early due to the :obj:`eos_token_id`.
If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the
possible :class:`~transformers.file_utils.ModelOutput` types are:
- :class:`~transformers.generation_utils.TFGreedySearchDecoderOnlyOutput`,
- :class:`~transformers.generation_utils.TFSampleDecoderOnlyOutput`,
- :class:`~transformers.generation_utils.TFBeamSearchDecoderOnlyOutput`,
- :class:`~transformers.generation_utils.TFBeamSampleDecoderOnlyOutput`
If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible
:class:`~transformers.file_utils.ModelOutput` types are:
- :class:`~transformers.generation_utils.TFGreedySearchEncoderDecoderOutput`,
- :class:`~transformers.generation_utils.TFSampleEncoderDecoderOutput`,
- :class:`~transformers.generation_utils.TFBeamSearchEncoderDecoderOutput`,
- :class:`~transformers.generation_utils.TFBeamSampleEncoderDecoderOutput`
Examples::
@ -229,6 +567,22 @@ class TFGenerationMixin:
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
model_kwargs["output_scores"] = output_scores
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
if self.config.is_encoder_decoder:
model_kwargs["encoder_attentions"] = None
model_kwargs["encoder_hidden_states"] = None
if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
else:
@ -319,7 +673,17 @@ class TFGenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
@ -394,6 +758,8 @@ class TFGenerationMixin:
use_cache=use_cache,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
else:
output = self._generate_no_beam_search(
@ -415,6 +781,8 @@ class TFGenerationMixin:
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
return output
@ -439,8 +807,9 @@ class TFGenerationMixin:
encoder_outputs,
attention_mask,
use_cache,
return_dict_in_generate,
**kwargs
):
) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]:
"""
Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
independently.
@ -452,12 +821,51 @@ class TFGenerationMixin:
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if self.config.is_encoder_decoder:
encoder_attentions = (
kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
)
encoder_hidden_states = (
kwargs["encoder_hidden_states"]
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=kwargs["output_attentions"],
output_hidden_states=kwargs["output_hidden_states"],
)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if kwargs["output_scores"]:
scores += (next_token_logits,)
if kwargs["output_attentions"]:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if kwargs["output_hidden_states"]:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
@ -580,7 +988,45 @@ class TFGenerationMixin:
else:
decoded = input_ids
return decoded
if return_dict_in_generate:
if do_sample:
if self.config.is_encoder_decoder:
return TFSampleEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFSampleDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
if self.config.is_encoder_decoder:
return TFGreedySearchEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFGreedySearchDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return decoded
def _generate_beam_search(
self,
@ -608,8 +1054,9 @@ class TFGenerationMixin:
use_cache,
forced_bos_token_id,
forced_eos_token_id,
return_dict_in_generate,
**kwargs,
):
) -> Union[TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
"""Generate sequences for each example with beam search."""
# generated hypotheses
@ -632,6 +1079,22 @@ class TFGenerationMixin:
past = encoder_outputs
# to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if self.config.is_encoder_decoder:
encoder_attentions = (
kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
)
encoder_hidden_states = (
kwargs["encoder_hidden_states"]
if (return_dict_in_generate and kwargs["encoder_hidden_states"])
else None
)
# done sentences
done = [False for _ in range(batch_size)]
@ -639,8 +1102,13 @@ class TFGenerationMixin:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=kwargs["output_attentions"],
output_hidden_states=kwargs["output_hidden_states"],
)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
@ -751,6 +1219,24 @@ class TFGenerationMixin:
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if kwargs["output_scores"]:
scores += (next_token_logits,)
if kwargs["output_attentions"]:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if kwargs["output_hidden_states"]:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# next batch beam content
next_batch_beam = []
@ -911,7 +1397,43 @@ class TFGenerationMixin:
assert (len(hypo) == max_length for hypo in best)
decoded = tf.stack(best)
return decoded
if return_dict_in_generate:
if do_sample and self.config.is_encoder_decoder:
return TFBeamSampleEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
elif do_sample and not self.config.is_encoder_decoder:
return TFBeamSampleDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
elif self.config.is_encoder_decoder:
return TFBeamSearchEncoderDecoderOutput(
sequences=decoded,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFBeamSearchDecoderOnlyOutput(
sequences=decoded,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return decoded
@staticmethod
def _reorder_cache(past, beam_idx):

View File

@ -1063,7 +1063,11 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
**kwargs
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
**model_kwargs
):
"""
Implements TFRAG token decoding.
@ -1137,6 +1141,18 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
output_attentions (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more details.
output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more details.
output_scores (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
:obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
@ -1167,6 +1183,21 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
else self.config.generator.decoder_start_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
model_kwargs["output_scores"] = output_scores
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["encoder_attentions"] = None
model_kwargs["encoder_hidden_states"] = None
# retrieve docs
if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
@ -1200,7 +1231,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
batch_size = context_input_ids.shape[0] // n_docs
encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
encoder_outputs = encoder(
input_ids=context_input_ids,
attention_mask=context_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
if return_dict_in_generate:
if output_attentions:
model_kwargs["encoder_attentions"] = encoder_outputs.attentions
if output_hidden_states:
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
decoder_input_ids = tf.fill(
(batch_size * num_beams, 1),
@ -1238,9 +1281,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
# define start_len & additional parameters
cur_len = 1
vocab_size = self.config.generator.vocab_size
kwargs["doc_scores"] = doc_scores
kwargs["encoder_outputs"] = encoder_outputs
kwargs["n_docs"] = n_docs
model_kwargs["doc_scores"] = doc_scores
model_kwargs["encoder_outputs"] = encoder_outputs
model_kwargs["n_docs"] = n_docs
# not needed. TODO(PVP): change after generate refactor
do_sample = False
@ -1274,7 +1317,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
use_cache=use_cache,
forced_bos_token_id=None,
forced_eos_token_id=None,
**kwargs, # encoder_outputs is here as in Pytorch's version
return_dict_in_generate=return_dict_in_generate,
**model_kwargs, # encoder_outputs is here as in Pytorch's version
)
else:
return self._generate_no_beam_search(
@ -1297,7 +1341,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
use_cache=use_cache,
forced_bos_token_id=None,
forced_eos_token_id=None,
**kwargs, # encoder_outputs is here as in Pytorch's version
return_dict_in_generate=return_dict_in_generate,
**model_kwargs, # encoder_outputs is here as in Pytorch's version
)
def get_input_embeddings(self):

View File

@ -1481,6 +1481,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs, past_key_values = past, None
else:
encoder_outputs, past_key_values = past[0], past[1]
if "encoder_hidden_states" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
if "encoder_attentions" in kwargs:
encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])
# cut decoder_input_ids if past is used
if past_key_values is not None:

View File

@ -796,7 +796,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
if inputs["output_attentions"]:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if inputs["target_mapping"] is not None:
# when target_mapping is provided, there are 2-tuple of attentions
attentions = tuple(
tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
)
else:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if not inputs["return_dict"]:
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

View File

@ -61,6 +61,16 @@ if is_tf_available():
TFSharedEmbeddings,
tf_top_k_top_p_filtering,
)
from transformers.generation_tf_utils import (
TFBeamSampleDecoderOnlyOutput,
TFBeamSampleEncoderDecoderOutput,
TFBeamSearchDecoderOnlyOutput,
TFBeamSearchEncoderDecoderOutput,
TFGreedySearchDecoderOnlyOutput,
TFGreedySearchEncoderDecoderOutput,
TFSampleDecoderOnlyOutput,
TFSampleEncoderDecoderOutput,
)
if _tf_gpu_memory_limit is not None:
gpus = tf.config.list_physical_devices("GPU")
@ -1100,6 +1110,37 @@ class TFModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_lm_head_model_no_beam_search_generate_dict_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config)
output_greedy = model.generate(
input_ids,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
output_sample = model.generate(
input_ids,
do_sample=True,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_greedy, TFGreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_sample, TFSampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_greedy, TFGreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_sample, TFSampleDecoderOnlyOutput)
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get("input_ids", None)
@ -1140,6 +1181,39 @@ class TFModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_lm_head_model_beam_search_generate_dict_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config)
output_beam_search = model.generate(
input_ids,
num_beams=2,
do_sample=False,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
output_beam_sample = model.generate(
input_ids,
num_beams=2,
do_sample=True,
output_scores=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_search, TFBeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_beam_sample, TFBeamSampleEncoderDecoderOutput)
else:
self.assertIsInstance(output_beam_search, TFBeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_beam_sample, TFBeamSampleDecoderOnlyOutput)
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: