Generate: TF uses `GenerationConfig` as the basis for `.generate()` parametrization (#20994)

Joao Gante 2023-01-04 18:23:20 +00:00 committed by GitHub
parent 3b309818e7
commit a6c850e4f4
4 changed files with 440 additions and 574 deletions
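
After this change, every generation-capable TF model carries a `generation_config` attribute, and `.generate()` reads its defaults from that object instead of ad hoc arguments pulled from the model config. A minimal usage sketch of the new parametrization (the checkpoint and the values are illustrative, not taken from this commit):

from transformers import AutoTokenizer, GenerationConfig, TFAutoModelForSeq2SeqLM

# Illustrative checkpoint; any TF model with generation support behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Every generation-capable TF model now carries a GenerationConfig instance.
print(model.generation_config)

# An explicit config can be passed to generate(); kwargs that match
# GenerationConfig attributes (num_beams here) override it on the fly.
generation_config = GenerationConfig(max_new_tokens=32, num_beams=4)
inputs = tokenizer("translate English to German: Hello world", return_tensors="tf")
outputs = model.generate(**inputs, generation_config=generation_config, num_beams=2)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))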

File diff suppressed because it is too large

src/transformers/modeling_tf_utils.py

@@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import TFGenerationMixin
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
@@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])
@@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
raise NotImplementedError
def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
return False
return True
def get_input_embeddings(self) -> tf.keras.layers.Layer:
"""
Returns the model's input embeddings layer.
@@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
" to use it for predictions and inference."
)
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,

src/transformers/models/rag/modeling_tf_rag.py

@@ -15,6 +15,7 @@
"""TFRAG model implementation."""
import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
max_length=None,
min_length=None,
early_stopping=None,
use_cache=None,
num_beams=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
**model_kwargs
generation_config=None,
**kwargs
):
"""
Implements TFRAG token decoding.
@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
not.
use_cache: (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids(`List[int]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_return_sequences (`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`)
Number of documents to retrieve and/or number of documents for which to generate an answer.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
`tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
due to the `eos_token_id`.
"""
# Handle `generation_config` and kwargs that might update it
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.config.generator.decoder_start_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# retrieve docs
if self.retriever is not None and context_input_ids is None:
@@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
encoder_outputs = encoder(
input_ids=context_input_ids,
attention_mask=context_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
return_dict=True,
)
decoder_input_ids = tf.fill(
(batch_size * num_beams, 1),
tf.cast(decoder_start_token_id, tf.int32),
(batch_size * generation_config.num_beams, 1),
tf.cast(generation_config.decoder_start_token_id, tf.int32),
)
last_hidden_state = encoder_outputs["last_hidden_state"]
@@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return tf.reshape(tensor, new_shape)
# correctly extend last_hidden_state and attention mask
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(
last_hidden_state, num_beams=generation_config.num_beams
)
doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)
# define start_len & additional parameters
model_kwargs["doc_scores"] = doc_scores
@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
model_kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
repetition_penalty=self.config.repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=self.config.generator.forced_bos_token_id,
forced_eos_token_id=self.config.generator.forced_eos_token_id,
generation_config=generation_config,
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
)
if num_beams == 1:
if generation_config.num_beams == 1:
return self.greedy_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
logits_processor=pre_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
elif num_beams > 1:
if num_beams < num_return_sequences:
elif generation_config.num_beams > 1:
if generation_config.num_beams < generation_config.num_return_sequences:
raise ValueError(
"Beam search decoding cannot return more sequences than it has beams. Please set "
f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
"Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
f" num_return_sequences, got {generation_config.num_beams} and"
f" {generation_config.num_return_sequences} (respectively)"
)
def unflatten_beam_dim(tensor):
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
shape = shape_list(tensor)
return tf.reshape(tensor, [-1, num_beams] + shape[1:])
return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])
decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return self.beam_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
logits_processor=pre_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
raise ValueError(
f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
)
def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings()
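
The rewritten `generate()` above funnels all parametrization through a single `GenerationConfig`: the passed (or default) config is deep-copied, updated with any matching kwargs, and whatever is left over is forwarded as model kwargs. A standalone sketch of that update pattern (the values are illustrative):

import copy

from transformers import GenerationConfig

def resolve_generation_kwargs(default_config, generation_config=None, **kwargs):
    # Mirrors the handling at the top of TFRagTokenForGeneration.generate():
    # fall back to the default config, never mutate it, and split the kwargs.
    if generation_config is None:
        generation_config = default_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # unused kwargs are model kwargs
    return generation_config, model_kwargs

default = GenerationConfig(num_beams=4, max_length=20)
config, model_kwargs = resolve_generation_kwargs(default, num_beams=2, n_docs=5)
print(config.num_beams)  # 2, overriding the default of 4
print(model_kwargs)      # {'n_docs': 5}, forwarded to the model's forward pass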

tests/test_modeling_tf_common.py

@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
def _generate_and_check_results(model, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
if config.pad_token_id is not None:
if model.generation_config.pad_token_id is not None:
if config.pad_token_id == 0:
new_pad_token = config.pad_token_id + 1
new_pad_token = model.generation_config.pad_token_id + 1
else:
new_pad_token = config.pad_token_id - 1
new_pad_token = model.generation_config.pad_token_id - 1
else:
new_pad_token = None
inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
elif "input_features" in inputs_dict:
inputs = inputs_dict["input_features"]
else:
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
model = model_class(config)
if model.supports_xla_generation:
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
else:
with self.assertRaises(ValueError):
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
def test_xla_generate_fast(self):
"""