Generate: TF contrastive search with XLA support (#20050)

* Add contrastive search
Joao Gante 2022-11-07 10:54:29 +00:00 committed by GitHub
parent 504db92e7d
commit a0f8674303
5 changed files with 770 additions and 46 deletions

View File

@ -38,6 +38,7 @@ from .generation_tf_logits_process import (
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .modeling_tf_outputs import TFCausalLMOutputWithPast, TFSeq2SeqLMOutput
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@ -345,15 +346,99 @@ class TFBeamSampleEncoderDecoderOutput(ModelOutput):
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFContrastiveSearchDecoderOnlyOutput(ModelOutput):
"""
Base class for outputs of decoder-only generation models using contrastive search.
Args:
sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each
generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFContrastiveSearchEncoderDecoderOutput(ModelOutput):
"""
Base class for outputs of encoder-decoder generation models using contrastive search. Hidden states and attention
weights of the encoder can be accessed via the encoder_attentions and encoder_hidden_states attributes, and those
of the decoder via the decoder_attentions and decoder_hidden_states attributes.
Args:
sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each
generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer of the encoder) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: tf.Tensor = None
scores: Optional[Tuple[tf.Tensor]] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput]
TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput]
TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput]
TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput]
TFContrastiveSearchOutput = Union[TFContrastiveSearchEncoderDecoderOutput, TFContrastiveSearchDecoderOnlyOutput]
TFGenerateOutput = Union[
TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, TFContrastiveSearchOutput
]
class TFGenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
The class exposes [`~generation_tf_utils.TFGenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if
`penalty_alpha>0` and `top_k>1`.
- *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
"""
_seed_generator = None
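For reference, a minimal sketch of how these flags select a decoding strategy when calling `generate()`; the checkpoint and prompt are illustrative only, not taken from this PR:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="tf")

greedy = model.generate(**inputs, num_beams=1, do_sample=False)      # greedy decoding
contrastive = model.generate(**inputs, penalty_alpha=0.6, top_k=4)   # contrastive search (added in this PR)
sampled = model.generate(**inputs, num_beams=1, do_sample=True)      # multinomial sampling
beam = model.generate(**inputs, num_beams=4, do_sample=False)        # beam-search decoding
```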
@ -386,6 +471,7 @@ class TFGenerationMixin:
early_stopping=None,
num_beams=None,
temperature=None,
penalty_alpha=None,
top_k=None,
top_p=None,
repetition_penalty=None,
@ -409,10 +495,19 @@ class TFGenerationMixin:
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
) -> Union[TFGenerateOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1`
and `do_sample=False`.
- *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if
`penalty_alpha>0` and `top_k>1`.
- *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1`
and `do_sample=False`.
Adapted in part from [Facebook's XLM beam search
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
@ -447,6 +542,8 @@ class TFGenerationMixin:
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
@ -606,6 +703,7 @@ class TFGenerationMixin:
early_stopping=early_stopping,
num_beams=num_beams,
temperature=temperature,
penalty_alpha=penalty_alpha,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
@ -1374,6 +1472,7 @@ class TFGenerationMixin:
early_stopping=None,
num_beams=None,
temperature=None,
penalty_alpha=None,
top_k=None,
top_p=None,
repetition_penalty=None,
@ -1400,8 +1499,17 @@ class TFGenerationMixin:
**model_kwargs,
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1`
and `do_sample=False`.
- *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if
`penalty_alpha>0` and `top_k>1`.
- *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1`
and `do_sample=False`.
Adapted in part from [Facebook's XLM beam search
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
@ -1433,6 +1541,8 @@ class TFGenerationMixin:
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
@ -1726,9 +1836,12 @@ class TFGenerationMixin:
# 7. determine generation mode
# TODO(Matt, Joao, Patrick) - add more use cases here
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_contrastive_search_gen_mode = (
top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
)
is_greedy_gen_mode = not is_contrastive_search_gen_mode and (num_beams == 1) and do_sample is False
is_beam_gen_mode = not is_contrastive_search_gen_mode and (num_beams > 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
@ -1752,7 +1865,7 @@ class TFGenerationMixin:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# 9. run greedy search
# 10. run greedy search
return self.greedy_search(
input_ids,
max_length=max_length,
@ -1763,13 +1876,31 @@ class TFGenerationMixin:
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
)
# 10. run contrastive search
return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
logits_processor=logits_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
input_ids=input_ids,
expand_size=num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
@ -1913,32 +2044,33 @@ class TFGenerationMixin:
@staticmethod
def _expand_inputs_for_generation(
input_ids: tf.Tensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: Optional[tf.Tensor] = None,
encoder_outputs: Optional[ModelOutput] = None,
input_ids: Optional[tf.Tensor] = None,
**model_kwargs,
) -> Tuple[tf.Tensor, Dict[str, Any]]:
expanded_return_idx = tf.reshape(
tf.tile(tf.reshape(tf.range(tf.shape(input_ids)[0]), (-1, 1)), (1, expand_size)), (-1,)
)
input_ids = tf.gather(input_ids, expanded_return_idx, axis=0)
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
if input_ids is not None:
input_ids = tf.repeat(input_ids, expand_size, axis=0)
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0)
if model_kwargs.get("token_type_ids") is not None:
model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0)
if attention_mask is not None:
model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0)
if model_kwargs.get("attention_mask") is not None:
model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0)
if model_kwargs.get("decoder_attention_mask") is not None:
model_kwargs["decoder_attention_mask"] = tf.repeat(
model_kwargs["decoder_attention_mask"], expand_size, axis=0
)
if is_encoder_decoder:
encoder_outputs = model_kwargs.get("encoder_outputs")
if encoder_outputs is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
encoder_outputs["last_hidden_state"] = tf.gather(
encoder_outputs.last_hidden_state, expanded_return_idx, axis=0
)
encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0)
model_kwargs["encoder_outputs"] = encoder_outputs
return input_ids, model_kwargs
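For reference, a tiny standalone sketch of the expansion performed here (toy tensor; `tf.repeat` keeps the copies of each example adjacent, which is what beam search and contrastive search rely on):

```python
import tensorflow as tf

input_ids = tf.constant([[1, 2], [3, 4]])           # [batch_size=2, seq_len=2]
expanded = tf.repeat(input_ids, repeats=3, axis=0)   # [batch_size * 3, seq_len]
# expanded -> [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
```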
def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None):
@ -1956,18 +2088,21 @@ class TFGenerationMixin:
return inputs
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput):
past = None
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
elif "past_buckets_states" in outputs:
past = outputs.past_buckets_states
return past
def _update_model_kwargs_for_generation(
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
# update past
if "past_key_values" in outputs:
model_kwargs["past"] = outputs.past_key_values
elif "mems" in outputs:
model_kwargs["past"] = outputs.mems
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs.past_buckets_states
else:
model_kwargs["past"] = None
model_kwargs["past"] = self._extract_past_from_model_output(outputs)
# update attention mask
if not is_encoder_decoder:
@ -2077,13 +2212,8 @@ class TFGenerationMixin:
)
return new_past
if "past_key_values" in model_outputs:
past = model_outputs.past_key_values
elif "mems" in model_outputs:
past = model_outputs.mems
elif "past_buckets_states" in model_outputs:
past = model_outputs.past_buckets_states
else:
past = self._extract_past_from_model_output(model_outputs)
if past is None:
raise ValueError(
f"No known past variable found in model outputs (model outputs keys: {list(model_outputs.keys())})"
)
@ -3192,6 +3322,400 @@ class TFGenerationMixin:
else:
return sequences
def contrastive_search(
self,
input_ids: tf.Tensor,
top_k: Optional[int] = 1,
penalty_alpha: Optional[float] = 0,
logits_processor: Optional[TFLogitsProcessorList] = None,
logits_warper: Optional[TFLogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
**model_kwargs,
) -> Union[TFContrastiveSearchOutput, tf.Tensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
top_k (`int`, *optional*, defaults to 1):
The size of the candidate set that is used to re-rank for contrastive search.
penalty_alpha (`float`, *optional*, defaults to 0):
The degeneration penalty for contrastive search; it is active when larger than 0.
logits_processor (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`TFLogitsProcessorList`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`]
used to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
model_kwargs:
Additional model-specific keyword arguments will be forwarded to the `call` function of the model. If the
model is an encoder-decoder model, the kwargs should include `encoder_outputs`.
Return:
[`~generation_tf_utils.TFContrastiveSearchDecoderOnlyOutput`],
[`~generation_tf_utils.TFContrastiveSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing
the generated tokens (default behaviour) or a
[`~generation_tf_utils.TFContrastiveSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False`
and `return_dict_in_generate=True` or a [`~generation_tf_utils.TFContrastiveSearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
Examples:
```python
>>> from transformers import AutoTokenizer, TFAutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> model = TFAutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> # set pad_token_id to eos_token_id because OPT does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="tf")
>>> outputs = model.contrastive_search(**input_ids, penalty_alpha=0.6, top_k=4, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMinds mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it']
```"""
def gather_best_candidate(nested, selected_idx_stacked, batch_axis=0):
"""Gathers the slices indexed by selected_idx_stacked from a potentially nested structure of tensors."""
def gather_fn(tensor):
gathered_tensor = tf.gather(params=tensor, indices=selected_idx_stacked, axis=batch_axis)
return gathered_tensor
return tf.nest.map_structure(gather_fn, nested)
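# e.g. when applied to the nested `past_key_values`, this keeps only the rows (along `batch_axis`) belonging to
# the selected candidate of each batch entry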
# 1. init contrastive_search values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self)
cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0
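# (e.g. TFGPT2 stacks key and value on dimension 0 of each per-layer cache tensor, so its batch dimension
# sits on axis 1, hence `cache_batch_axis = 1` for those models)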
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# 3. init tensors to use for "xla-compilable" generate function
batch_size, cur_len = shape_list(input_ids)
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
generated = tf.concat([input_ids, input_ids_padding], axis=-1)
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
# 4. define "xla-compilable" stop-condition and auto-regressive function
# define condition fn
def contrastive_search_cond_fn(
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
):
"""state termination condition fn."""
return ~tf.reduce_all(finished_sequences)
# define body fn
def contrastive_search_body_fn(
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
):
"""state update fn."""
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
# prepare inputs
model_inputs = self.prepare_inputs_for_generation(generated[:, :cur_len], **model_kwargs)
model_inputs["use_cache"] = True
# encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save
# the `encoder_outputs`
outputs = self(
**model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
# previous tokens)
if self.config.is_encoder_decoder:
last_hidden_states = outputs.decoder_hidden_states[-1]
else:
last_hidden_states = outputs.hidden_states[-1]
# XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across
# iterations (with fixed shapes)
if use_xla:
last_hidden_states = tf.pad(last_hidden_states, [[0, 0], [0, max_length - cur_len], [0, 0]])
# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = outputs.logits[:, -1, :]
if use_xla:
model_kwargs = self._update_model_kwargs_for_xla_generation(
model_outputs=outputs,
model_kwargs=model_kwargs,
cur_len=cur_len,
max_length=max_length,
batch_size=batch_size,
is_encoder_decoder=self.config.is_encoder_decoder,
batch_axis=cache_batch_axis,
)
else:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
past = model_kwargs.get("past")
if past is None:
raise ValueError(
f"{self.__class__.__name__} does not support caching and therefore **can't** be used "
"for contrastive search."
)
elif not isinstance(past[0], (tuple, tf.Tensor)) or past[0][0].shape[0] != batch_size:
raise ValueError(
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
"used for contrastive search without further modifications."
)
else:
logit_for_next_step = next_step_cached_variables["logit_for_next_step"]
last_hidden_states = next_step_cached_variables["last_hidden_states"]
outputs = next_step_cached_variables["outputs"]
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
# degeneration penalty
logit_for_next_step = logits_processor(generated, logit_for_next_step, cur_len)
logit_for_next_step = logits_warper(generated, logit_for_next_step, cur_len)
next_probs = stable_softmax(logit_for_next_step, axis=-1)
top_k_probs, top_k_ids = tf.math.top_k(next_probs, k=top_k)
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
scores.append(outputs.logits[:, -1])
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
decoder_attentions.append(outputs.attentions)
if self.config.is_encoder_decoder:
cross_attentions.append(outputs.cross_attentions)
if output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.decoder_hidden_states)
elif output_hidden_states and not self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.hidden_states)
# Replicates the new past_key_values to match the `top_k` candidates
model_kwargs["past"] = tf.nest.map_structure(
lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past"]
)
# compute the candidate tokens by the language model and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(tf.reshape(top_k_ids, [-1, 1]), **model_kwargs)
next_model_inputs["use_cache"] = True
outputs = self(
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
next_past_key_values = self._extract_past_from_model_output(outputs)
logits = outputs.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
context_hidden = tf.repeat(last_hidden_states[:, :cur_len, :], top_k, axis=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
# converts indices from a dimension of top_k to the stacked top_k * batch_size dimension, for indexing
# without a need to reshape on tensors that have these two dimensions stacked
selected_idx_stacked = selected_idx + tf.range(selected_idx.shape[0], dtype=tf.int64) * top_k
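# e.g. with batch_size=2 and top_k=3: selected_idx=[1, 2] -> selected_idx_stacked=[1, 5]
# (candidate 1 of example 0 and candidate 2 of example 1 in the stacked [batch_size * top_k] layout)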
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
# (model confidence minus degeneration penalty); (6) decoder hidden_states
next_tokens = tf.gather(top_k_ids, selected_idx, axis=1, batch_dims=1)
next_hidden = gather_best_candidate(next_hidden, selected_idx_stacked)
# XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across
# iterations (with fixed shapes)
if use_xla:
last_hidden_states = dynamic_update_slice(last_hidden_states, next_hidden, [0, cur_len, 0])
else:
last_hidden_states = tf.concat([last_hidden_states, next_hidden], axis=1)
next_decoder_hidden_states = gather_best_candidate(full_hidden_states, selected_idx_stacked)
next_past_key_values = gather_best_candidate(
next_past_key_values, selected_idx_stacked, batch_axis=cache_batch_axis
)
logit_for_next_step = gather_best_candidate(logits, selected_idx_stacked)
# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
if self.config.is_encoder_decoder:
next_step_cross_attentions = ()
next_step_decoder_attentions = ()
if output_attentions:
next_step_cross_attentions = gather_best_candidate(outputs.cross_attentions, selected_idx_stacked)
next_step_decoder_attentions = gather_best_candidate(
outputs.decoder_attentions, selected_idx_stacked
)
outputs = TFSeq2SeqLMOutput(
past_key_values=next_past_key_values,
decoder_hidden_states=next_decoder_hidden_states,
decoder_attentions=next_step_decoder_attentions or None,
cross_attentions=next_step_cross_attentions or None,
)
else:
next_step_attentions = ()
if output_attentions:
next_step_attentions = gather_best_candidate(outputs.attentions, selected_idx_stacked)
outputs = TFCausalLMOutputWithPast(
past_key_values=next_past_key_values,
hidden_states=next_decoder_hidden_states,
attentions=next_step_attentions or None,
)
# contrastive_search main logic end
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
# update `generated` and `cur_len`
update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1)
generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens)
cur_len += 1
if use_xla:
# NOTE: 1) relative to other generation strategies, contrastive search is always running forward
# passes one step ahead -- hence the `cur_len=cur_len + 1`; 2) the attention mask here is expanded from
# [batch_size, ...] to [batch_size*top_k, ...] -- hence the `batch_size=batch_size * top_k`
model_kwargs = self._update_model_kwargs_for_xla_generation(
model_outputs=outputs,
model_kwargs=model_kwargs,
cur_len=cur_len + 1,
max_length=max_length,
batch_size=batch_size * top_k,
is_encoder_decoder=self.config.is_encoder_decoder,
batch_axis=cache_batch_axis,
)
else:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
next_step_cached_variables = {
"logit_for_next_step": logit_for_next_step,
"last_hidden_states": last_hidden_states,
"outputs": outputs,
}
return generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
# 5. run generation
# 1st generation step has to be run first to initialize `past`
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables = contrastive_search_body_fn(
generated, finished_sequences, cur_len, model_kwargs, None
)
# 2-to-n generation steps can then be run in autoregressive fashion
# but only if the 1st generation step did NOT yield an EOS token
if contrastive_search_cond_fn(
generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables
):
maximum_iterations = max_length - cur_len
generated, _, cur_len, _, _, = tf.while_loop(
contrastive_search_cond_fn,
contrastive_search_body_fn,
(generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs
if not use_xla:
# cut for backward compatibility
generated = generated[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
# if model is an encoder-decoder, retrieve encoder attention weights
# and hidden states
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
scores = tuple(scores) if scores is not None else None
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
return TFContrastiveSearchEncoderDecoderOutput(
sequences=generated,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return TFContrastiveSearchDecoderOnlyOutput(
sequences=generated,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return generated
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
# create logit penalties for already seen input_ids
@ -3386,3 +3910,26 @@ class BeamHypotheses(object):
cur_score = best_sum_logprobs / cur_len**self.length_penalty
ret = self.worst_score >= cur_score
return ret
def _ranking_fast(
context_hidden: tf.Tensor,
next_hidden: tf.Tensor,
next_top_k_probs: tf.Tensor,
alpha: float,
beam_width: int,
) -> tf.Tensor:
"""
Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described
in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each
row in the batch.
"""
norm_context_hidden = context_hidden / tf.norm(context_hidden, axis=2, keepdims=True)
norm_next_hidden = next_hidden / tf.norm(next_hidden, axis=2, keepdims=True)
cosine_matrix = tf.squeeze(tf.linalg.matmul(norm_context_hidden, norm_next_hidden, transpose_b=True), axis=-1)
degeneration_penalty = tf.reduce_max(cosine_matrix, axis=-1)
next_top_k_probs = tf.reshape(next_top_k_probs, shape=[-1])
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
contrastive_score = tf.reshape(contrastive_score, shape=[-1, beam_width])
selected_idx = tf.argmax(contrastive_score, axis=1)
return selected_idx
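For reference, a toy re-ranking example for a single batch row with `top_k=2`, following the formula above (the numbers are made up, not from any model):

```python
import tensorflow as tf

alpha = 0.6
top_k_probs = tf.constant([0.50, 0.30])           # model confidence of the two candidates
degeneration_penalty = tf.constant([0.90, 0.20])  # max cosine similarity with the previous context

contrastive_score = (1.0 - alpha) * top_k_probs - alpha * degeneration_penalty
# candidate 0: 0.4 * 0.50 - 0.6 * 0.90 = -0.34
# candidate 1: 0.4 * 0.30 - 0.6 * 0.20 =  0.00  -> selected, as it repeats the context less
selected_idx = tf.argmax(contrastive_score)       # tf.Tensor(1)
```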

View File

@ -651,6 +651,7 @@ class GenerationMixin:
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
"""Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
@ -1860,7 +1861,7 @@ class GenerationMixin:
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> # set pad_token_id to eos_token_id because OPT does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
@ -1916,7 +1917,7 @@ class GenerationMixin:
if this_peer_finished_flag.item() == 0.0:
break
# if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past") is None:
@ -2014,7 +2015,7 @@ class GenerationMixin:
full_hidden_states = outputs.hidden_states
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)

View File

@ -550,6 +550,100 @@ class TFBartModelIntegrationTest(unittest.TestCase):
def tok(self):
return BartTokenizer.from_pretrained("facebook/bart-large")
@slow
def test_contrastive_search_bart(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
).input_ids
outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@slow
def test_contrastive_search_bart_xla(self):
article = (
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
" year later, she got married again in Westchester County, but to a different man and without divorcing"
" her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
" once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
" license application, according to court documents. Prosecutors said the marriages were part of an"
" immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
" her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
" arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
" York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
" Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All"
" occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
" married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
" said the immigration scam involved some of her husbands, who filed for permanent residence status"
" shortly after the marriages. Any divorces happened only after such filings were approved. It was"
" unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
" Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
" Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
" native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
" up to four years in prison. Her next court appearance is scheduled for May 18."
)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
input_ids = bart_tokenizer(
article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf"
).input_ids
xla_generate = tf.function(bart_model.generate, jit_compile=True)
# no_repeat_ngram_size set to 0 because it isn't compatible with XLA, but doesn't change the original output
outputs = xla_generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, no_repeat_ngram_size=0)
generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. "
"Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is "
"accused of being part of an immigration scam to get permanent residency. If convicted, she faces up "
"to four years in"
],
)
@slow
@require_tf

View File

@ -663,3 +663,72 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
self.assertListEqual(output_strings, expected_output_strings)
@slow
def test_contrastive_search_gpt2(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
input_ids = gpt2_tokenizer(article, return_tensors="tf")
outputs = gpt2_model.generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)
@slow
def test_contrastive_search_gpt2_xla(self):
article = (
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large")
input_ids = gpt2_tokenizer(article, return_tensors="tf")
xla_generate = tf.function(gpt2_model.generate, jit_compile=True)
outputs = xla_generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256)
generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research "
"laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, "
"United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as "
"Google Now, which helps users find the information they're looking for on the web. But the company "
"is not the only one to collect data on its users. Facebook, for example, has its own facial "
"recognition technology, as well as a database of millions of photos that it uses to personalize its "
"News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates "
"concerned about the company's ability to keep users' information private. In a blog post last "
'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our '
'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with '
'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at '
'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, '
"but said in a statement to The Associated Press that"
],
)

View File

@ -1783,7 +1783,7 @@ class TFModelTesterMixin:
model.compile(optimizer="sgd", run_eagerly=True)
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, num_beams, num_return_sequences, max_length):
def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
@ -1801,9 +1801,9 @@ class TFModelTesterMixin:
else:
raise ValueError("No valid generate input found in inputs_dict")
generated = model.generate(inputs).numpy()
generated = model.generate(inputs, **generate_kwargs).numpy()
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(inputs).numpy()
generated_xla = generate_xla(inputs, **generate_kwargs).numpy()
self.assertListEqual(generated.tolist(), generated_xla.tolist())
for model_class in self.all_generative_model_classes:
@ -1844,6 +1844,19 @@ class TFModelTesterMixin:
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length)
def test_xla_generate_contrastive(self):
"""
Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the
model cache and other outputs, and this test ensures that they are in a valid format that is also supported
by XLA.
Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception.
"""
num_beams = 1
num_return_sequences = 1
max_length = 10
self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5)
@slow
def test_xla_generate_slow(self):
"""