Generate: TF uses `GenerationConfig` as the basis for `.generate()` parametrization (#20994)
commit a6c850e4f4
parent 3b309818e7
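In short: TF generative models now carry a `generation_config` attribute, `from_pretrained` tries to load a `generation_config.json` next to the weights, and `.generate()` draws its defaults from that config, with call-time kwargs acting as ad hoc overrides. A hedged usage sketch of the resulting flow (checkpoint name illustrative, not part of this diff):

```python
# Assumed illustration only: "gpt2" stands in for any TF generative checkpoint.
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

print(model.generation_config)  # base parametrization used by .generate()

inputs = tokenizer("Hello", return_tensors="tf")
# kwargs matching GenerationConfig attributes override the stored defaults ad hoc
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))
```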
src/transformers/modeling_tf_utils.py
@@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import TFGenerationMixin
+from .generation import GenerationConfig, TFGenerationMixin
 from .tf_utils import shape_list
 from .utils import (
     DUMMY_INPUTS,
@@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
         self._set_save_spec(self.serving.input_signature[0])

@@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         raise NotImplementedError

+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
+
     def get_input_embeddings(self) -> tf.keras.layers.Layer:
         """
         Returns the model's input embeddings layer.
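`can_generate` leans on the fact that the repr of a bound method contains the qualified name of the class that defined it, so a model that never overrode `prepare_inputs_for_generation` still reports the `GenerationMixin` base. A minimal standalone sketch of that detection trick (classes here are illustrative, not library code):

```python
# The repr of a bound method embeds the defining class's qualified name.
class Base:
    def prepare_inputs_for_generation(self):
        raise NotImplementedError


class WithOverride(Base):
    def prepare_inputs_for_generation(self):
        return {}


print(str(Base().prepare_inputs_for_generation))
# e.g. "<bound method Base.prepare_inputs_for_generation of <__main__.Base ...>>"
print("Base" in str(Base().prepare_inputs_for_generation))          # True  -> no override
print("Base" in str(WithOverride().prepare_inputs_for_generation))  # False -> overridden
```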
@@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 " to use it for predictions and inference."
             )

+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
         if output_loading_info:
             loading_info = {
                 "missing_keys": missing_keys,
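The `try`/`except OSError` above can be reproduced by hand when only the generation defaults are needed; a hedged sketch of the same fallback (checkpoint name illustrative):

```python
from transformers import AutoConfig, GenerationConfig

checkpoint = "gpt2"  # assumed example checkpoint
try:
    # preferred source: the repo's dedicated generation_config.json
    generation_config = GenerationConfig.from_pretrained(checkpoint)
except OSError:
    # fallback: derive the generation defaults from the model config, as above
    model_config = AutoConfig.from_pretrained(checkpoint)
    generation_config = GenerationConfig.from_model_config(model_config)
```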
src/transformers/models/rag/modeling_tf_rag.py
@@ -15,6 +15,7 @@

 """TFRAG model implementation."""

+import copy
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         context_input_ids=None,
         context_attention_mask=None,
         doc_scores=None,
-        max_length=None,
-        min_length=None,
-        early_stopping=None,
-        use_cache=None,
-        num_beams=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        bad_words_ids=None,
-        num_return_sequences=None,
-        decoder_start_token_id=None,
         n_docs=None,
-        output_scores=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict_in_generate=None,
-        **model_kwargs
+        generation_config=None,
+        **kwargs
     ):
         """
         Implements TFRAG token decoding.
@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss

                 If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                 forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
-            max_length (`int`, *optional*, defaults to 20):
-                The maximum length of the sequence to be generated.
-            min_length (`int`, *optional*, defaults to 10):
-                The minimum length of the sequence to be generated.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
-                not.
-            use_cache (`bool`, *optional*, defaults to `True`):
-                Whether or not the model should use the past last key/values attentions (if applicable to the model)
-                to speed up decoding.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            length_penalty (`float`, *optional*, defaults to 1.0):
-                Exponential penalty to the length that is used with beam-based generation. It is applied as an
-                exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the
-                score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer
-                sequences, while `length_penalty` < 0.0 encourages shorter sequences.
-            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
-                If set to int > 0, all ngrams of that size can only occur once.
-            bad_words_ids (`List[int]`, *optional*):
-                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            num_return_sequences (`int`, *optional*, defaults to 1):
-                The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function, where
-                we set `num_return_sequences` to `num_beams`.
-            decoder_start_token_id (`int`, *optional*):
-                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
-            n_docs (`int`, *optional*, defaults to `config.n_docs`):
-                Number of documents to retrieve and/or number of documents for which to generate an answer.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            model_specific_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            kwargs:
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.

         Return:
             `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
             second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished
             early due to the `eos_token_id`.
         """
+        # Handle `generation_config` and kwargs that might update it
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+
         # set default parameters
         n_docs = n_docs if n_docs is not None else self.config.n_docs
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
-        num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
-        )
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.config.generator.decoder_start_token_id
-        )
-
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
-        )

         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
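The five added lines after the docstring are the core of the new parametrization: deep-copy the base config, fold call-time kwargs into it, and keep whatever did not match as model kwargs. A hedged standalone sketch of `GenerationConfig.update` semantics as used here (attribute values illustrative):

```python
import copy

from transformers import GenerationConfig

base = GenerationConfig(num_beams=4, max_length=32)  # stands in for model.generation_config

generation_config = copy.deepcopy(base)  # never mutate the model's own copy
unused = generation_config.update(num_beams=2, n_docs=5)

print(generation_config.num_beams)  # 2 -- matching kwarg overrode the attribute
print(base.num_beams)               # 4 -- the base config is untouched
print(unused)                       # {'n_docs': 5} -- left over, forwarded as a model kwarg
```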
@@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             encoder_outputs = encoder(
                 input_ids=context_input_ids,
                 attention_mask=context_attention_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 return_dict=True,
             )

             decoder_input_ids = tf.fill(
-                (batch_size * num_beams, 1),
-                tf.cast(decoder_start_token_id, tf.int32),
+                (batch_size * generation_config.num_beams, 1),
+                tf.cast(generation_config.decoder_start_token_id, tf.int32),
             )
             last_hidden_state = encoder_outputs["last_hidden_state"]
@@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             return tf.reshape(tensor, new_shape)

         # correctly extend last_hidden_state and attention mask
-        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
-        encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
+        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
+        encoder_outputs["last_hidden_state"] = extend_enc_output(
+            last_hidden_state, num_beams=generation_config.num_beams
+        )

-        doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
+        doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)

         # define start_len & additional parameters
         model_kwargs["doc_scores"] = doc_scores
@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         model_kwargs["n_docs"] = n_docs

         pre_processor = self._get_logits_processor(
-            repetition_penalty=self.config.repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            bad_words_ids=bad_words_ids,
-            min_length=min_length,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            forced_bos_token_id=self.config.generator.forced_bos_token_id,
-            forced_eos_token_id=self.config.generator.forced_eos_token_id,
+            generation_config=generation_config,
+            input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
         )

-        if num_beams == 1:
+        if generation_config.num_beams == 1:
             return self.greedy_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
-        elif num_beams > 1:
-            if num_beams < num_return_sequences:
+        elif generation_config.num_beams > 1:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                 raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectively)"
                 )

             def unflatten_beam_dim(tensor):
                 """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
                 shape = shape_list(tensor)
-                return tf.reshape(tensor, [-1, num_beams] + shape[1:])
+                return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])

             decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
             model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss

             return self.beam_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
         else:
-            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
+            raise ValueError(
+                f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+            )

     def get_input_embeddings(self):
         return self.rag.generator.get_input_embeddings()
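With the reworked signature, RAG token decoding follows the same convention; a hedged calling sketch, assuming `model` (a `TFRagTokenForGeneration`) and `input_ids` were prepared earlier:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig(num_beams=4, max_length=30)

# `generation_config` supplies the base values, matching kwargs override them
# ad hoc, and RAG-specific arguments such as `n_docs` keep dedicated parameters.
outputs = model.generate(
    input_ids,
    generation_config=gen_config,
    num_return_sequences=2,
    n_docs=5,
)
```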
tests/test_modeling_tf_common.py
@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
             model.train_on_batch(test_batch, test_batch_labels)

     def _test_xla_generate(self, **generate_kwargs):
-        def _generate_and_check_results(model, config, inputs_dict):
+        def _generate_and_check_results(model, inputs_dict):
             if "input_ids" in inputs_dict:
                 inputs = inputs_dict["input_ids"]
                 # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
-                if config.pad_token_id is not None:
+                if model.generation_config.pad_token_id is not None:
                     if config.pad_token_id == 0:
-                        new_pad_token = config.pad_token_id + 1
+                        new_pad_token = model.generation_config.pad_token_id + 1
                     else:
-                        new_pad_token = config.pad_token_id - 1
+                        new_pad_token = model.generation_config.pad_token_id - 1
                 else:
                     new_pad_token = None
-                inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
+                inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
             elif "input_features" in inputs_dict:
                 inputs = inputs_dict["input_features"]
             else:

@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
             model = model_class(config)

             if model.supports_xla_generation:
-                _generate_and_check_results(model, config, inputs_dict)
+                _generate_and_check_results(model, inputs_dict)
             else:
                 with self.assertRaises(ValueError):
-                    _generate_and_check_results(model, config, inputs_dict)
+                    _generate_and_check_results(model, inputs_dict)

     def test_xla_generate_fast(self):
         """
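The pad-token swap in `_generate_and_check_results` makes sure the prompt contains no accidental padding before XLA generation; a minimal runnable sketch of the same `tf.where` pattern:

```python
import tensorflow as tf

pad_token_id = 0
inputs = tf.constant([[0, 5, 6], [7, 0, 8]])

# pick a neighboring id so the replacement is still a valid token
new_pad_token = pad_token_id + 1 if pad_token_id == 0 else pad_token_id - 1
inputs = tf.where(inputs != pad_token_id, inputs, new_pad_token)

print(inputs.numpy())  # [[1 5 6] [7 1 8]]
```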