Add output in a dictionary for TF `generate` method (#12139)
* Add output args to greedy search
* Fix critical typo + make style quality
* Handle generate_beam_search
* Add dict-specific tests and fix the placement of encoder outputs
* Add specific outputs
* Update doc
* Fix typo
* Adjust handling of encoder_outputs + fix generating for T5
* Fix generate for RAG
* Fix handling of output_attentions when target_mapping is not None

  Take care of situations where target_mapping is provided, since there is then a 2-tuple of attentions per layer.

  Change from::

      if inputs["output_attentions"]:
          attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

  to::

      if inputs["output_attentions"]:
          if inputs["target_mapping"] is not None:
              # when target_mapping is provided, there are 2-tuple of attentions
              attentions = tuple(
                  tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
              )
          else:
              attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

* Rename kwargs to model_kwargs
* make style quality
* Move imports in test_modeling_tf_common.py

  Move the ModelOutput-related imports in test_modeling_tf_common.py into the `is_tf_available():` statement.

* Rewrite nested if-statements
* Fix added tests
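As a quick orientation (not part of the commit itself), a minimal sketch of how the new dictionary outputs are consumed; the checkpoint name, prompt, and max_length below are illustrative only::

    from transformers import AutoTokenizer, TFAutoModelForCausalLM
    from transformers.generation_tf_utils import TFGreedySearchDecoderOnlyOutput

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any TF causal LM checkpoint
    model = TFAutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("Hello, my dog is", return_tensors="tf").input_ids
    outputs = model.generate(
        input_ids,
        max_length=20,
        do_sample=False,
        output_scores=True,
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    # decoder-only greedy search now returns a TFGreedySearchDecoderOnlyOutput
    assert isinstance(outputs, TFGreedySearchDecoderOnlyOutput)
    print(outputs.sequences.shape)  # (batch_size, sequence_length)
    print(len(outputs.scores))      # one (batch_size, vocab_size) tensor per generated token
    print(len(outputs.attentions))  # one tuple (one entry per layer) per generated token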
parent d4be498441
commit 26a2e36595

@@ -14,15 +14,323 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from .file_utils import ModelOutput
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
class TFGreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with
            each tensor of shape :obj:`(batch_size, config.vocab_size)`).
        attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFGreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size, config.vocab_size)`).
        encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using sampling.

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with
            each tensor of shape :obj:`(batch_size*num_return_sequences, config.vocab_size)`).
        attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(num_return_sequences*batch_size, num_heads, generated_length,
            sequence_length)`.
        hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(num_return_sequences*batch_size, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
    the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size*num_return_sequences, config.vocab_size)`).
        encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape
            :obj:`(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFBeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
            . :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
            hidden_size)`.
    """

    sequences: tf.Tensor = None
    sequences_scores: Optional[tf.Tensor] = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFBeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
            . :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size*num_beams, config.vocab_size)`).
        attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
        encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, generated_length,
            sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
            hidden_size)`.
    """

    sequences: tf.Tensor = None
    sequences_scores: Optional[tf.Tensor] = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFBeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam sample.

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequence)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
            . :obj:`(max_length-input_ids.shape[-1],)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    sequences_scores: Optional[tf.Tensor] = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFBeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)

    Args:
        sequences (:obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or
            shorter if all batches finished early due to the :obj:`eos_token_id`.
        sequences_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequence)`, `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Final beam scores of the generated ``sequences``.
        scores (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_scores=True`` is passed or when ``config.output_scores=True``):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam
            . :obj:`(max_length-1,)`-shaped tuple of :obj:`tf.Tensor` with each tensor of shape
            :obj:`(batch_size*num_beams, config.vocab_size)`).
        encoder_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer of the decoder) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
        encoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
            shape :obj:`(batch_size*num_beams, sequence_length, hidden_size)`.
        decoder_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        cross_attentions (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
        decoder_hidden_states (:obj:`tuple(tuple(tf.Tensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            :obj:`tf.Tensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    sequences_scores: Optional[tf.Tensor] = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput]
TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput]
TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput]
TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput]


class TFGenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in
@@ -67,9 +375,14 @@ class TFGenerationMixin:
        attention_mask=None,
        decoder_start_token_id=None,
        use_cache=None,
        output_scores=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict_in_generate=None,
        forced_bos_token_id=None,
        forced_eos_token_id=None,
    ):
        **model_kwargs,
    ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

@@ -139,6 +452,16 @@ class TFGenerationMixin:
            use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
                speed up decoding.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            forced_bos_token_id (:obj:`int`, `optional`):
                The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
                Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token

@@ -149,10 +472,25 @@ class TFGenerationMixin:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

        Return:
            :class:`~transformers.file_utils.ModelOutput` or :obj:`tf.Tensor`: A
            :class:`~transformers.file_utils.ModelOutput` (if ``return_dict_in_generate=True`` or when
            ``config.return_dict_in_generate=True``) or a :obj:`tf.Tensor`.

            :obj:`tf.Tensor` of :obj:`dtype=tf.int32` and shape :obj:`(batch_size * num_return_sequences,
            sequence_length)`: The generated sequences. The second dimension (sequence_length) is either equal to
            :obj:`max_length` or shorter if all batches finished early due to the :obj:`eos_token_id`.
            If the model is `not` an encoder-decoder model (``model.config.is_encoder_decoder=False``), the
            possible :class:`~transformers.file_utils.ModelOutput` types are:

            - :class:`~transformers.generation_utils.TFGreedySearchDecoderOnlyOutput`,
            - :class:`~transformers.generation_utils.TFSampleDecoderOnlyOutput`,
            - :class:`~transformers.generation_utils.TFBeamSearchDecoderOnlyOutput`,
            - :class:`~transformers.generation_utils.TFBeamSampleDecoderOnlyOutput`

            If the model is an encoder-decoder model (``model.config.is_encoder_decoder=True``), the possible
            :class:`~transformers.file_utils.ModelOutput` types are:

            - :class:`~transformers.generation_utils.TFGreedySearchEncoderDecoderOutput`,
            - :class:`~transformers.generation_utils.TFSampleEncoderDecoderOutput`,
            - :class:`~transformers.generation_utils.TFBeamSearchEncoderDecoderOutput`,
            - :class:`~transformers.generation_utils.TFBeamSampleEncoderDecoderOutput`

        Examples::

@@ -229,6 +567,22 @@ class TFGenerationMixin:
            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
        )

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        model_kwargs["output_scores"] = output_scores
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        if self.config.is_encoder_decoder:
            model_kwargs["encoder_attentions"] = None
            model_kwargs["encoder_hidden_states"] = None

        if input_ids is not None:
            batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
        else:

@@ -319,7 +673,17 @@ class TFGenerationMixin:
            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
            encoder_outputs = encoder(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            if return_dict_in_generate:
                if output_attentions:
                    model_kwargs["encoder_attentions"] = encoder_outputs.attentions
                if output_hidden_states:
                    model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:

@@ -394,6 +758,8 @@ class TFGenerationMixin:
                use_cache=use_cache,
                forced_bos_token_id=forced_bos_token_id,
                forced_eos_token_id=forced_eos_token_id,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )
        else:
            output = self._generate_no_beam_search(

@@ -415,6 +781,8 @@ class TFGenerationMixin:
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,
            )

        return output

@@ -439,8 +807,9 @@ class TFGenerationMixin:
        encoder_outputs,
        attention_mask,
        use_cache,
        return_dict_in_generate,
        **kwargs
    ):
    ) -> Union[TFGreedySearchOutput, TFSampleOutput, tf.Tensor]:
        """
        Generate sequences for each example without beam search (num_beams == 1). All returned sequences are generated
        independently.

@@ -452,12 +821,51 @@ class TFGenerationMixin:

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
        decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if self.config.is_encoder_decoder:
            encoder_attentions = (
                kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
            )
            encoder_hidden_states = (
                kwargs["encoder_hidden_states"]
                if (return_dict_in_generate and kwargs["encoder_hidden_states"])
                else None
            )

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
            )
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=kwargs["output_attentions"],
                output_hidden_states=kwargs["output_hidden_states"],
            )
            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if kwargs["output_scores"]:
                    scores += (next_token_logits,)
                if kwargs["output_attentions"]:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if kwargs["output_hidden_states"]:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):

@@ -580,7 +988,45 @@ class TFGenerationMixin:
        else:
            decoded = input_ids

        return decoded
        if return_dict_in_generate:
            if do_sample:
                if self.config.is_encoder_decoder:
                    return TFSampleEncoderDecoderOutput(
                        sequences=decoded,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    return TFSampleDecoderOnlyOutput(
                        sequences=decoded,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
            else:
                if self.config.is_encoder_decoder:
                    return TFGreedySearchEncoderDecoderOutput(
                        sequences=decoded,
                        scores=scores,
                        encoder_attentions=encoder_attentions,
                        encoder_hidden_states=encoder_hidden_states,
                        decoder_attentions=decoder_attentions,
                        cross_attentions=cross_attentions,
                        decoder_hidden_states=decoder_hidden_states,
                    )
                else:
                    return TFGreedySearchDecoderOnlyOutput(
                        sequences=decoded,
                        scores=scores,
                        attentions=decoder_attentions,
                        hidden_states=decoder_hidden_states,
                    )
        else:
            return decoded

    def _generate_beam_search(
        self,

@@ -608,8 +1054,9 @@ class TFGenerationMixin:
        use_cache,
        forced_bos_token_id,
        forced_eos_token_id,
        return_dict_in_generate,
        **kwargs,
    ):
    ) -> Union[TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
        """Generate sequences for each example with beam search."""

        # generated hypotheses

@@ -632,6 +1079,22 @@ class TFGenerationMixin:
        past = encoder_outputs
        # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
        decoder_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        cross_attentions = () if (return_dict_in_generate and kwargs["output_attentions"]) else None
        decoder_hidden_states = () if (return_dict_in_generate and kwargs["output_hidden_states"]) else None
        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if self.config.is_encoder_decoder:
            encoder_attentions = (
                kwargs["encoder_attentions"] if (return_dict_in_generate and kwargs["encoder_attentions"]) else None
            )
            encoder_hidden_states = (
                kwargs["encoder_hidden_states"]
                if (return_dict_in_generate and kwargs["encoder_hidden_states"])
                else None
            )

        # done sentences
        done = [False for _ in range(batch_size)]

@@ -639,8 +1102,13 @@ class TFGenerationMixin:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
            )
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=kwargs["output_attentions"],
                output_hidden_states=kwargs["output_hidden_states"],
            )
            next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):

@@ -751,6 +1219,24 @@ class TFGenerationMixin:

            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if kwargs["output_scores"]:
                    scores += (next_token_logits,)
                if kwargs["output_attentions"]:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if kwargs["output_hidden_states"]:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # next batch beam content
            next_batch_beam = []

@@ -911,7 +1397,43 @@ class TFGenerationMixin:
        assert (len(hypo) == max_length for hypo in best)
        decoded = tf.stack(best)

        return decoded
        if return_dict_in_generate:
            if do_sample and self.config.is_encoder_decoder:
                return TFBeamSampleEncoderDecoderOutput(
                    sequences=decoded,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            elif do_sample and not self.config.is_encoder_decoder:
                return TFBeamSampleDecoderOnlyOutput(
                    sequences=decoded,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
            elif self.config.is_encoder_decoder:
                return TFBeamSearchEncoderDecoderOutput(
                    sequences=decoded,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return TFBeamSearchDecoderOnlyOutput(
                    sequences=decoded,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return decoded

    @staticmethod
    def _reorder_cache(past, beam_idx):
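For reference while reading the encoder-decoder changes below (again, not part of the commit), a hedged sketch of the beam-search dictionary output for a seq2seq model; the T5 checkpoint and input are placeholders::

    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
    from transformers.generation_tf_utils import TFBeamSearchEncoderDecoderOutput

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    input_ids = tokenizer("translate English to German: Thank you", return_tensors="tf").input_ids
    outputs = model.generate(
        input_ids,
        num_beams=2,
        do_sample=False,
        output_scores=True,
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    assert isinstance(outputs, TFBeamSearchEncoderDecoderOutput)
    print(outputs.sequences.shape)          # (batch_size * num_return_sequences, sequence_length)
    print(len(outputs.scores))              # processed beam scores, one tensor per generation step
    print(len(outputs.encoder_attentions))  # one tensor per encoder layer
    print(len(outputs.decoder_attentions))  # one tuple (one entry per decoder layer) per generated token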
@@ -1063,7 +1063,11 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        num_return_sequences=None,
        decoder_start_token_id=None,
        n_docs=None,
        **kwargs
        output_scores=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict_in_generate=None,
        **model_kwargs
    ):
        """
        Implements TFRAG token decoding.

@@ -1137,6 +1141,18 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
            n_docs (:obj:`int`, `optional`, defaults to :obj:`config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            output_attentions (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
                returned tensors for more details.
            output_hidden_states (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
                for more details.
            output_scores (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
            model_specific_kwargs:
                Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.

        Return:
            :obj:`tf.Tensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated

@@ -1167,6 +1183,21 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
            else self.config.generator.decoder_start_token_id
        )

        output_scores = output_scores if output_scores is not None else self.config.output_scores
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

        model_kwargs["output_scores"] = output_scores
        model_kwargs["output_attentions"] = output_attentions
        model_kwargs["output_hidden_states"] = output_hidden_states
        model_kwargs["encoder_attentions"] = None
        model_kwargs["encoder_hidden_states"] = None

        # retrieve docs
        if self.retriever is not None and context_input_ids is None:
            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]

@@ -1200,7 +1231,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        batch_size = context_input_ids.shape[0] // n_docs

        encoder = self.rag.generator.get_encoder()
        encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
        encoder_outputs = encoder(
            input_ids=context_input_ids,
            attention_mask=context_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if return_dict_in_generate:
            if output_attentions:
                model_kwargs["encoder_attentions"] = encoder_outputs.attentions
            if output_hidden_states:
                model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states

        decoder_input_ids = tf.fill(
            (batch_size * num_beams, 1),

@@ -1238,9 +1281,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        # define start_len & additional parameters
        cur_len = 1
        vocab_size = self.config.generator.vocab_size
        kwargs["doc_scores"] = doc_scores
        kwargs["encoder_outputs"] = encoder_outputs
        kwargs["n_docs"] = n_docs
        model_kwargs["doc_scores"] = doc_scores
        model_kwargs["encoder_outputs"] = encoder_outputs
        model_kwargs["n_docs"] = n_docs

        # not needed. TODO(PVP): change after generate refactor
        do_sample = False

@@ -1274,7 +1317,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                use_cache=use_cache,
                forced_bos_token_id=None,
                forced_eos_token_id=None,
                **kwargs,  # encoder_outputs is here as in Pytorch's version
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,  # encoder_outputs is here as in Pytorch's version
            )
        else:
            return self._generate_no_beam_search(

@@ -1297,7 +1341,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                use_cache=use_cache,
                forced_bos_token_id=None,
                forced_eos_token_id=None,
                **kwargs,  # encoder_outputs is here as in Pytorch's version
                return_dict_in_generate=return_dict_in_generate,
                **model_kwargs,  # encoder_outputs is here as in Pytorch's version
            )

    def get_input_embeddings(self):
@@ -1481,6 +1481,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
            encoder_outputs, past_key_values = past, None
        else:
            encoder_outputs, past_key_values = past[0], past[1]
            if "encoder_hidden_states" in kwargs:
                encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
            if "encoder_attentions" in kwargs:
                encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])

        # cut decoder_input_ids if past is used
        if past_key_values is not None:
@@ -796,7 +796,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
            else:
                hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
        if inputs["output_attentions"]:
            attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
            if inputs["target_mapping"] is not None:
                # when target_mapping is provided, there are 2-tuple of attentions
                attentions = tuple(
                    tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
                )
            else:
                attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

        if not inputs["return_dict"]:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
@@ -61,6 +61,16 @@ if is_tf_available():
        TFSharedEmbeddings,
        tf_top_k_top_p_filtering,
    )
    from transformers.generation_tf_utils import (
        TFBeamSampleDecoderOnlyOutput,
        TFBeamSampleEncoderDecoderOutput,
        TFBeamSearchDecoderOnlyOutput,
        TFBeamSearchEncoderDecoderOutput,
        TFGreedySearchDecoderOnlyOutput,
        TFGreedySearchEncoderDecoderOutput,
        TFSampleDecoderOnlyOutput,
        TFSampleEncoderDecoderOutput,
    )

    if _tf_gpu_memory_limit is not None:
        gpus = tf.config.list_physical_devices("GPU")

@@ -1100,6 +1110,37 @@ class TFModelTesterMixin:
            generated_ids = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    def test_lm_head_model_no_beam_search_generate_dict_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids", None)

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            output_greedy = model.generate(
                input_ids,
                do_sample=False,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            output_sample = model.generate(
                input_ids,
                do_sample=True,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_greedy, TFGreedySearchEncoderDecoderOutput)
                self.assertIsInstance(output_sample, TFSampleEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_greedy, TFGreedySearchDecoderOnlyOutput)
                self.assertIsInstance(output_sample, TFSampleDecoderOnlyOutput)

    def test_lm_head_model_random_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids", None)

@@ -1140,6 +1181,39 @@ class TFModelTesterMixin:
            generated_ids = output_tokens[:, input_ids.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    def test_lm_head_model_beam_search_generate_dict_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get("input_ids", None)

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            output_beam_search = model.generate(
                input_ids,
                num_beams=2,
                do_sample=False,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )
            output_beam_sample = model.generate(
                input_ids,
                num_beams=2,
                do_sample=True,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            if model.config.is_encoder_decoder:
                self.assertIsInstance(output_beam_search, TFBeamSearchEncoderDecoderOutput)
                self.assertIsInstance(output_beam_sample, TFBeamSampleEncoderDecoderOutput)
            else:
                self.assertIsInstance(output_beam_search, TFBeamSearchDecoderOnlyOutput)
                self.assertIsInstance(output_beam_sample, TFBeamSampleDecoderOnlyOutput)

    def test_loss_computation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes: