From a0f867430347bcf939f71d186409b9ca138c3b34 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 7 Nov 2022 10:54:29 +0000 Subject: [PATCH] Generate: TF contrastive search with XLA support (#20050) * Add contrastive search --- src/transformers/generation_tf_utils.py | 627 +++++++++++++++++++-- src/transformers/generation_utils.py | 7 +- tests/models/bart/test_modeling_tf_bart.py | 94 +++ tests/models/gpt2/test_modeling_tf_gpt2.py | 69 +++ tests/test_modeling_tf_common.py | 19 +- 5 files changed, 770 insertions(+), 46 deletions(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index ac41c65c54..8c52fec623 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -38,6 +38,7 @@ from .generation_tf_logits_process import ( TFTopKLogitsWarper, TFTopPLogitsWarper, ) +from .modeling_tf_outputs import TFCausalLMOutputWithPast, TFSeq2SeqLMOutput from .models.auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, @@ -345,15 +346,99 @@ class TFBeamSampleEncoderDecoderOutput(ModelOutput): decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None +@dataclass +class TFContrastiveSearchDecoderOnlyOutput(ModelOutput): + """ + Base class for outputs of decoder-only generation models using contrastive search. + + + Args: + sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: tf.Tensor = None + scores: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None + + +@dataclass +class TFContrastiveSearchEncoderDecoderOutput(ModelOutput): + """ + Base class for outputs of encoder-decoder generation models using contrastive search. Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + + + Args: + sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each + generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: tf.Tensor = None + scores: Optional[Tuple[tf.Tensor]] = None + encoder_attentions: Optional[Tuple[tf.Tensor]] = None + encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None + + TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput] TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput] TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput] TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput] +TFContrastiveSearchOutput = Union[TFContrastiveSearchEncoderDecoderOutput, TFContrastiveSearchDecoderOnlyOutput] +TFGenerateOutput = Union[ + TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, TFContrastiveSearchOutput +] class TFGenerationMixin: """ A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`]. + + The class exposes [`~generation_tf_utils.TFGenerationMixin.generate`], which can be used for: + - *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1` and + `do_sample=False`. + - *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if + `penalty_alpha>0` and `top_k>1` + - *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1` and + `do_sample=False`. """ _seed_generator = None @@ -386,6 +471,7 @@ class TFGenerationMixin: early_stopping=None, num_beams=None, temperature=None, + penalty_alpha=None, top_k=None, top_p=None, repetition_penalty=None, @@ -409,10 +495,19 @@ class TFGenerationMixin: begin_suppress_tokens: Optional[List[int]] = None, forced_decoder_ids: Optional[List[List[int]]] = None, **model_kwargs, - ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]: + ) -> Union[TFGenerateOutput, tf.Tensor]: r""" - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1` + and `do_sample=False`. + - *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if + `penalty_alpha>0` and `top_k>1` + - *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1` + and `do_sample=False`. Adapted in part from [Facebook's XLM beam search code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529). @@ -447,6 +542,8 @@ class TFGenerationMixin: Number of beams for beam search. 1 means no beam search. temperature (`float`, *optional*, defaults to 1.0): The value used to module the next token probabilities. + penalty_alpha (`float`, *optional*): + The values balance the model confidence and the degeneration penalty in contrastive search decoding. top_k (`int`, *optional*, defaults to 50): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to 1.0): @@ -606,6 +703,7 @@ class TFGenerationMixin: early_stopping=early_stopping, num_beams=num_beams, temperature=temperature, + penalty_alpha=penalty_alpha, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, @@ -1374,6 +1472,7 @@ class TFGenerationMixin: early_stopping=None, num_beams=None, temperature=None, + penalty_alpha=None, top_k=None, top_p=None, repetition_penalty=None, @@ -1400,8 +1499,17 @@ class TFGenerationMixin: **model_kwargs, ) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]: r""" - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. + Generates sequences of token ids for models with a language modeling head. The method supports the following + generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + + - *greedy decoding* by calling [`~generation_tf_utils.TFGenerationMixin.greedy_search`] if `num_beams=1` + and `do_sample=False`. + - *contrastive search* by calling [`~generation_tf_utils.TFGenerationMixin.contrastive_search`] if + `penalty_alpha>0` and `top_k>1` + - *multinomial sampling* by calling [`~generation_tf_utils.TFGenerationMixin.sample`] if `num_beams=1` and + `do_sample=True`. + - *beam-search decoding* by calling [`~generation_tf_utils.TFGenerationMixin.beam_search`] if `num_beams>1` + and `do_sample=False`. Adapted in part from [Facebook's XLM beam search code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529). @@ -1433,6 +1541,8 @@ class TFGenerationMixin: Number of beams for beam search. 1 means no beam search. temperature (`float`, *optional*, defaults to 1.0): The value used to module the next token probabilities. + penalty_alpha (`float`, *optional*): + The values balance the model confidence and the degeneration penalty in contrastive search decoding. top_k (`int`, *optional*, defaults to 50): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to 1.0): @@ -1726,9 +1836,12 @@ class TFGenerationMixin: # 7. determine generation mode # TODO(Matt, Joao, Patrick) - add more use cases here - is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_contrastive_search_gen_mode = ( + top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0 + ) + is_greedy_gen_mode = not is_contrastive_search_gen_mode and (num_beams == 1) and do_sample is False + is_beam_gen_mode = not is_contrastive_search_gen_mode and (num_beams > 1) and do_sample is False is_sample_gen_mode = (num_beams == 1) and do_sample is True - is_beam_gen_mode = (num_beams > 1) and do_sample is False # 8. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( @@ -1752,7 +1865,7 @@ class TFGenerationMixin: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - # 9. run greedy search + # 10. run greedy search return self.greedy_search( input_ids, max_length=max_length, @@ -1763,13 +1876,31 @@ class TFGenerationMixin: return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) + elif is_contrastive_search_gen_mode: + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search." + ) + # 10. run contrastive search + return self.contrastive_search( + input_ids, + top_k=top_k, + penalty_alpha=penalty_alpha, + logits_processor=logits_processor, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + **model_kwargs, + ) elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( - input_ids, + input_ids=input_ids, expand_size=num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, @@ -1913,32 +2044,33 @@ class TFGenerationMixin: @staticmethod def _expand_inputs_for_generation( - input_ids: tf.Tensor, expand_size: int = 1, is_encoder_decoder: bool = False, - attention_mask: Optional[tf.Tensor] = None, - encoder_outputs: Optional[ModelOutput] = None, + input_ids: Optional[tf.Tensor] = None, **model_kwargs, ) -> Tuple[tf.Tensor, Dict[str, Any]]: - expanded_return_idx = tf.reshape( - tf.tile(tf.reshape(tf.range(tf.shape(input_ids)[0]), (-1, 1)), (1, expand_size)), (-1,) - ) - input_ids = tf.gather(input_ids, expanded_return_idx, axis=0) + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + if input_ids is not None: + input_ids = tf.repeat(input_ids, expand_size, axis=0) - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = tf.gather(token_type_ids, expanded_return_idx, axis=0) + if model_kwargs.get("token_type_ids") is not None: + model_kwargs["token_type_ids"] = tf.repeat(model_kwargs["token_type_ids"], expand_size, axis=0) - if attention_mask is not None: - model_kwargs["attention_mask"] = tf.gather(attention_mask, expanded_return_idx, axis=0) + if model_kwargs.get("attention_mask") is not None: + model_kwargs["attention_mask"] = tf.repeat(model_kwargs["attention_mask"], expand_size, axis=0) + + if model_kwargs.get("decoder_attention_mask") is not None: + model_kwargs["decoder_attention_mask"] = tf.repeat( + model_kwargs["decoder_attention_mask"], expand_size, axis=0 + ) if is_encoder_decoder: + encoder_outputs = model_kwargs.get("encoder_outputs") if encoder_outputs is None: raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - encoder_outputs["last_hidden_state"] = tf.gather( - encoder_outputs.last_hidden_state, expanded_return_idx, axis=0 - ) + encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, expand_size, axis=0) model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs def _prepare_model_inputs(self, inputs: Optional[tf.Tensor] = None, bos_token_id: Optional[int] = None): @@ -1956,18 +2088,21 @@ class TFGenerationMixin: return inputs @staticmethod + def _extract_past_from_model_output(outputs: ModelOutput): + past = None + if "past_key_values" in outputs: + past = outputs.past_key_values + elif "mems" in outputs: + past = outputs.mems + elif "past_buckets_states" in outputs: + past = outputs.past_buckets_states + return past + def _update_model_kwargs_for_generation( - outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False ) -> Dict[str, Any]: # update past - if "past_key_values" in outputs: - model_kwargs["past"] = outputs.past_key_values - elif "mems" in outputs: - model_kwargs["past"] = outputs.mems - elif "past_buckets_states" in outputs: - model_kwargs["past"] = outputs.past_buckets_states - else: - model_kwargs["past"] = None + model_kwargs["past"] = self._extract_past_from_model_output(outputs) # update attention mask if not is_encoder_decoder: @@ -2077,13 +2212,8 @@ class TFGenerationMixin: ) return new_past - if "past_key_values" in model_outputs: - past = model_outputs.past_key_values - elif "mems" in model_outputs: - past = model_outputs.mems - elif "past_buckets_states" in model_outputs: - past = model_outputs.past_buckets_states - else: + past = self._extract_past_from_model_output(model_outputs) + if past is None: raise ValueError( f"No known past variable found in model outputs (model outputs keys: {list(model_outputs.keys())})" ) @@ -3192,6 +3322,400 @@ class TFGenerationMixin: else: return sequences + def contrastive_search( + self, + input_ids: tf.Tensor, + top_k: Optional[int] = 1, + penalty_alpha: Optional[float] = 0, + logits_processor: Optional[TFLogitsProcessorList] = None, + logits_warper: Optional[TFLogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[TFContrastiveSearchOutput, tf.Tensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **contrastive search** and can + be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + top_k (`int`, *optional*, defaults to 1): + The size of the candidate set that is used to re-rank for contrastive search + penalty_alpha (`float`, *optional*, defaults to 0): + The degeneration penalty for contrastive search; activate when it is larger than 0 + logits_processor (`TFLogitsProcessorList`, *optional*): + An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + logits_warper (`TFLogitsProcessorList`, *optional*): + An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsWarper`] + used to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `call` function of the model. If + model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation_tf_utils.TFContrastiveSearchDecoderOnlyOutput`], + [`~generation_tf_utils.TFContrastiveSearchEncoderDecoderOutput`] or `tf.Tensor`: A `tf.Tensor` containing + the generated tokens (default behaviour) or a + [`~generation_tf_utils.TFContrastiveySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` + and `return_dict_in_generate=True` or a [`~generation_tf_utils.TFContrastiveSearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + ```python + >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + >>> model = TFAutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> # set pad_token_id to eos_token_id because OPT does not have a PAD token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> input_prompt = "DeepMind Company is" + >>> input_ids = tokenizer(input_prompt, return_tensors="tf") + >>> outputs = model.contrastive_search(**input_ids, penalty_alpha=0.6, top_k=4, max_length=64) + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['DeepMind Company is a company that focuses on the development and commercialization of artificial intelligence (AI). DeepMindā€™s mission is to help people understand and solve problems that are difficult to solve in the world today.\n\nIn this post, we talk about the benefits of deep learning in business and how it'] + ```""" + + def gather_best_candidate(nested, selected_idx_stacked, batch_axis=0): + """Gathers the slices indexed by selected_idx_stacked from a potentially nested structure of tensors.""" + + def gather_fn(tensor): + gathered_tensor = tf.gather(params=tensor, indices=selected_idx_stacked, axis=batch_axis) + return gathered_tensor + + return tf.nest.map_structure(gather_fn, nested) + + # 1. init greedy_search values + logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + output_scores = output_scores if output_scores is not None else self.config.output_scores + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate + ) + use_xla = not tf.executing_eagerly() + # TODO (Joao): fix cache format or find programatic way to detect cache index + # GPT2 and other models has a slightly different cache structure, with a different batch axis + model_name = str(self.decoder) if "EncoderDecoder" in str(self) else str(self) + cache_batch_axis = 1 if any([model_prefix in model_name for model_prefix in ("TFGPT2", "TFCTRL")]) else 0 + + # 2. init `attentions`, `hidden_states`, and `scores` tuples + scores = [] if (return_dict_in_generate and output_scores) else None + decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None + cross_attentions = [] if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None + + # 3. init tensors to use for "xla-compileable" generate function + batch_size, cur_len = shape_list(input_ids) + + # initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences` + input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0) + generated = tf.concat([input_ids, input_ids_padding], axis=-1) + finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) + + # 4. define "xla-compile-able" stop-condition and auto-regressive function + # define condition fn + def contrastive_search_cond_fn( + generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables + ): + """state termination condition fn.""" + return ~tf.reduce_all(finished_sequences) + + # define condition fn + def contrastive_search_body_fn( + generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables + ): + """state update fn.""" + + # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; + # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step + if model_kwargs.get("past") is None: + + # prepare inputs + model_inputs = self.prepare_inputs_for_generation(generated[:, :cur_len], **model_kwargs) + model_inputs["use_cache"] = True + + # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save + # the `encoder_outputs` + outputs = self( + **model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + + # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with + # previous tokens) + if self.config.is_encoder_decoder: + last_hidden_states = outputs.decoder_hidden_states[-1] + else: + last_hidden_states = outputs.hidden_states[-1] + + # XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across + # iterations (with fixed shapes) + if use_xla: + last_hidden_states = tf.pad(last_hidden_states, [[0, 0], [0, max_length - cur_len], [0, 0]]) + + # next logit for contrastive search to select top-k candidate tokens + logit_for_next_step = outputs.logits[:, -1, :] + + if use_xla: + model_kwargs = self._update_model_kwargs_for_xla_generation( + model_outputs=outputs, + model_kwargs=model_kwargs, + cur_len=cur_len, + max_length=max_length, + batch_size=batch_size, + is_encoder_decoder=self.config.is_encoder_decoder, + batch_axis=cache_batch_axis, + ) + else: + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # Expands model inputs top_k times, for batched forward passes (akin to beam search). + _, model_kwargs = self._expand_inputs_for_generation( + expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + + past = model_kwargs.get("past") + if past is None: + raise ValueError( + f"{self.__class__.__name__} does not support caching and therefore **can't** be used " + "for contrastive search." + ) + elif not isinstance(past[0], (tuple, tf.Tensor)) or past[0][0].shape[0] != batch_size: + raise ValueError( + f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be " + "used for contrastive search without further modifications." + ) + else: + logit_for_next_step = next_step_cached_variables["logit_for_next_step"] + last_hidden_states = next_step_cached_variables["last_hidden_states"] + outputs = next_step_cached_variables["outputs"] + + # contrastive_search main logic start: + # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by + # degeneration penalty + + logit_for_next_step = logits_processor(generated, logit_for_next_step, cur_len) + logit_for_next_step = logits_warper(generated, logit_for_next_step, cur_len) + next_probs = stable_softmax(logit_for_next_step, axis=-1) + top_k_probs, top_k_ids = tf.math.top_k(next_probs, k=top_k) + + # Store scores, attentions and hidden_states when required + if not use_xla and return_dict_in_generate: + if output_scores: + scores.append(outputs.logits[:, -1]) + if output_attentions and self.config.is_encoder_decoder: + decoder_attentions.append(outputs.decoder_attentions) + elif output_attentions and not self.config.is_encoder_decoder: + decoder_attentions.append(outputs.attentions) + if self.config.is_encoder_decoder: + cross_attentions.append(outputs.cross_attentions) + + if output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.decoder_hidden_states) + elif output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.hidden_states) + + # Replicates the new past_key_values to match the `top_k` candidates + model_kwargs["past"] = tf.nest.map_structure( + lambda tensor: tf.repeat(tensor, top_k, axis=cache_batch_axis), model_kwargs["past"] + ) + + # compute the candidate tokens by the language model and collects their hidden_states + next_model_inputs = self.prepare_inputs_for_generation(tf.reshape(top_k_ids, [-1, 1]), **model_kwargs) + next_model_inputs["use_cache"] = True + outputs = self( + **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions + ) + next_past_key_values = self._extract_past_from_model_output(outputs) + + logits = outputs.logits[:, -1, :] + # name is different for encoder-decoder and decoder-only models + if self.config.is_encoder_decoder: + next_hidden = outputs.decoder_hidden_states[-1] + full_hidden_states = outputs.decoder_hidden_states + else: + next_hidden = outputs.hidden_states[-1] + full_hidden_states = outputs.hidden_states + context_hidden = tf.repeat(last_hidden_states[:, :cur_len, :], top_k, axis=0) + + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the + # model confidence + selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) + + # converts indices to a dimension of top_k to the stacked top_k * batch_size dimension, for indexing + # without a need to reshape on tensors that have these two dimensions stacked + selected_idx_stacked = selected_idx + tf.range(selected_idx.shape[0], dtype=tf.int64) * top_k + + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing + # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores + # (model confidence minus degeneration penalty); (6) decoder hidden_states + next_tokens = tf.gather(top_k_ids, selected_idx, axis=1, batch_dims=1) + next_hidden = gather_best_candidate(next_hidden, selected_idx_stacked) + + # XLA: last_hidden_states normally grows at each step, but in XLA it is padded so as to be used across + # iterations (with fixed shapes) + if use_xla: + last_hidden_states = dynamic_update_slice(last_hidden_states, next_hidden, [0, cur_len, 0]) + else: + last_hidden_states = tf.concat([last_hidden_states, next_hidden], axis=1) + + next_decoder_hidden_states = gather_best_candidate(full_hidden_states, selected_idx_stacked) + next_past_key_values = gather_best_candidate( + next_past_key_values, selected_idx_stacked, batch_axis=cache_batch_axis + ) + logit_for_next_step = gather_best_candidate(logits, selected_idx_stacked) + + # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration + if self.config.is_encoder_decoder: + next_step_cross_attentions = () + next_step_decoder_attentions = () + if output_attentions: + next_step_cross_attentions = gather_best_candidate(outputs.cross_attentions, selected_idx_stacked) + next_step_decoder_attentions = gather_best_candidate( + outputs.decoder_attentions, selected_idx_stacked + ) + outputs = TFSeq2SeqLMOutput( + past_key_values=next_past_key_values, + decoder_hidden_states=next_decoder_hidden_states, + decoder_attentions=next_step_decoder_attentions or None, + cross_attentions=next_step_cross_attentions or None, + ) + else: + next_step_attentions = () + if output_attentions: + next_step_attentions = gather_best_candidate(outputs.attentions, selected_idx_stacked) + outputs = TFCausalLMOutputWithPast( + past_key_values=next_past_key_values, + hidden_states=next_decoder_hidden_states, + attentions=next_step_attentions or None, + ) + # contrastive_search main logic end + + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) + next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) + finished_sequences = finished_sequences | (next_tokens == eos_token_id) + + # update `generated` and `cur_len` + update_indices = tf.stack([tf.range(batch_size), tf.broadcast_to(cur_len, [batch_size])], axis=-1) + generated = tf.tensor_scatter_nd_update(tensor=generated, indices=update_indices, updates=next_tokens) + cur_len += 1 + + if use_xla: + # NOTE: 1) relative to other generation strategies, contrastive search is always running forward + # passes one step ahead -- hence the `cur_len=cur_len + 1`; 2) the attention mask here is expanded from + # [batch_size, ...] to [batch_size*top_k, ...] -- hence the `batch_size=batch_size * top_k` + model_kwargs = self._update_model_kwargs_for_xla_generation( + model_outputs=outputs, + model_kwargs=model_kwargs, + cur_len=cur_len + 1, + max_length=max_length, + batch_size=batch_size * top_k, + is_encoder_decoder=self.config.is_encoder_decoder, + batch_axis=cache_batch_axis, + ) + else: + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + next_step_cached_variables = { + "logit_for_next_step": logit_for_next_step, + "last_hidden_states": last_hidden_states, + "outputs": outputs, + } + return generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables + + # 5. run generation + # 1st generation step has to be run before to initialize `past` + generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables = contrastive_search_body_fn( + generated, finished_sequences, cur_len, model_kwargs, None + ) + + # 2-to-n generation steps can then be run in autoregressive fashion + # only in case 1st generation step does NOT yield EOS token though + if contrastive_search_cond_fn( + generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables + ): + maximum_iterations = max_length - cur_len + generated, _, cur_len, _, _, = tf.while_loop( + contrastive_search_cond_fn, + contrastive_search_body_fn, + (generated, finished_sequences, cur_len, model_kwargs, next_step_cached_variables), + maximum_iterations=maximum_iterations, + ) + + # 6. prepare outputs + if not use_xla: + # cut for backward compatibility + generated = generated[:, :cur_len] + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + # if model is an encoder-decoder, retrieve encoder attention weights + # and hidden states + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + scores = tuple(scores) if scores is not None else None + decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None + cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None + decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None + + return TFContrastiveSearchEncoderDecoderOutput( + sequences=generated, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return TFContrastiveSearchDecoderOnlyOutput( + sequences=generated, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return generated + def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): # create logit penalties for already seen input_ids @@ -3386,3 +3910,26 @@ class BeamHypotheses(object): cur_score = best_sum_logprobs / cur_len**self.length_penalty ret = self.worst_score >= cur_score return ret + + +def _ranking_fast( + context_hidden: tf.Tensor, + next_hidden: tf.Tensor, + next_top_k_probs: tf.Tensor, + alpha: float, + beam_width: int, +) -> tf.Tensor: + """ + Reranks the top_k candidates based on a degeneration penalty (cosine similarity with previous tokens), as described + in the paper "A Contrastive Framework for Neural Text Generation". Returns the index of the best candidate for each + row in the batch. + """ + norm_context_hidden = context_hidden / tf.norm(context_hidden, axis=2, keepdims=True) + norm_next_hidden = next_hidden / tf.norm(next_hidden, axis=2, keepdims=True) + cosine_matrix = tf.squeeze(tf.linalg.matmul(norm_context_hidden, norm_next_hidden, transpose_b=True), axis=-1) + degeneration_penalty = tf.reduce_max(cosine_matrix, axis=-1) + next_top_k_probs = tf.reshape(next_top_k_probs, shape=[-1]) + contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty + contrastive_score = tf.reshape(contrastive_score, shape=[-1, beam_width]) + selected_idx = tf.argmax(contrastive_score, axis=1) + return selected_idx diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 72d310cd01..24ddf913ba 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -651,6 +651,7 @@ class GenerationMixin: input_ids: Optional[torch.LongTensor] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" if input_ids is not None: input_ids = input_ids.repeat_interleave(expand_size, dim=0) @@ -1860,7 +1861,7 @@ class GenerationMixin: >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token + >>> # set pad_token_id to eos_token_id because OPT does not have a PAD token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "DeepMind Company is" >>> input_ids = tokenizer(input_prompt, return_tensors="pt") @@ -1916,7 +1917,7 @@ class GenerationMixin: if this_peer_finished_flag.item() == 0.0: break - # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; + # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past") is None: @@ -2014,7 +2015,7 @@ class GenerationMixin: full_hidden_states = outputs.hidden_states context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) - # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the + # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the # model confidence selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) diff --git a/tests/models/bart/test_modeling_tf_bart.py b/tests/models/bart/test_modeling_tf_bart.py index f47824fc08..e4c4f43c42 100644 --- a/tests/models/bart/test_modeling_tf_bart.py +++ b/tests/models/bart/test_modeling_tf_bart.py @@ -550,6 +550,100 @@ class TFBartModelIntegrationTest(unittest.TestCase): def tok(self): return BartTokenizer.from_pretrained("facebook/bart-large") + @slow + def test_contrastive_search_bart(self): + article = ( + " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + input_ids = bart_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf" + ).input_ids + + outputs = bart_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. " + "Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is " + "accused of being part of an immigration scam to get permanent residency. If convicted, she faces up " + "to four years in" + ], + ) + + @slow + def test_contrastive_search_bart_xla(self): + article = ( + " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A" + " year later, she got married again in Westchester County, but to a different man and without divorcing" + " her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos" + ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married' + " once more, this time in the Bronx. In an application for a marriage license, she stated it was her" + ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false' + ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage' + " license application, according to court documents. Prosecutors said the marriages were part of an" + " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to" + " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was" + " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New" + " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total," + " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All" + " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be" + " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors" + " said the immigration scam involved some of her husbands, who filed for permanent residence status" + " shortly after the marriages. Any divorces happened only after such filings were approved. It was" + " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District" + " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's" + ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,' + " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his" + " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces" + " up to four years in prison. Her next court appearance is scheduled for May 18." + ) + bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") + bart_model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") + input_ids = bart_tokenizer( + article, add_special_tokens=False, truncation=True, max_length=512, return_tensors="tf" + ).input_ids + + xla_generate = tf.function(bart_model.generate, jit_compile=True) + # no_repeat_ngram_size set to 0 because it isn't compatible with XLA, but doesn't change the original output + outputs = xla_generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64, no_repeat_ngram_size=0) + generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "Liana Barrientos, 39, pleaded not guilty to charges related to false marriage statements. " + "Prosecutors say she married at least 10 times, sometimes within two weeks of each other. She is " + "accused of being part of an immigration scam to get permanent residency. If convicted, she faces up " + "to four years in" + ], + ) + @slow @require_tf diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index d97a2b3ed9..64cbea4de9 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -663,3 +663,72 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): output_ids = xla_generate(**input_ids, do_sample=False, num_beams=2) output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) self.assertListEqual(output_strings, expected_output_strings) + + @slow + def test_contrastive_search_gpt2(self): + article = ( + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based" + ) + + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") + gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large") + input_ids = gpt2_tokenizer(article, return_tensors="tf") + + outputs = gpt2_model.generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + + generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, " + "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as " + "Google Now, which helps users find the information they're looking for on the web. But the company " + "is not the only one to collect data on its users. Facebook, for example, has its own facial " + "recognition technology, as well as a database of millions of photos that it uses to personalize its " + "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates " + "concerned about the company's ability to keep users' information private. In a blog post last " + 'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our ' + 'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with ' + 'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at ' + 'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, ' + "but said in a statement to The Associated Press that" + ], + ) + + @slow + def test_contrastive_search_gpt2_xla(self): + article = ( + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based" + ) + + gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large") + gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2-large") + input_ids = gpt2_tokenizer(article, return_tensors="tf") + + xla_generate = tf.function(gpt2_model.generate, jit_compile=True) + outputs = xla_generate(**input_ids, penalty_alpha=0.6, top_k=4, max_length=256) + + generated_text = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True) + + self.assertListEqual( + generated_text, + [ + "DeepMind Technologies is a British artificial intelligence subsidiary of Alphabet Inc. and research " + "laboratory founded in 2010. DeepMind was acquired by Google in 2014. The company is based in London, " + "United Kingdom\n\nGoogle has a lot of data on its users and uses it to improve its products, such as " + "Google Now, which helps users find the information they're looking for on the web. But the company " + "is not the only one to collect data on its users. Facebook, for example, has its own facial " + "recognition technology, as well as a database of millions of photos that it uses to personalize its " + "News Feed.\n\nFacebook's use of data is a hot topic in the tech industry, with privacy advocates " + "concerned about the company's ability to keep users' information private. In a blog post last " + 'year, Facebook CEO Mark Zuckerberg said his company would "do our best to be transparent about our ' + 'data use and how we use it."\n\n"We have made it clear that we do not sell or share your data with ' + 'third parties," Zuckerberg wrote. "If you have questions or concerns, please reach out to us at ' + 'privacy@facebook.com."\n\nGoogle declined to comment on the privacy implications of its use of data, ' + "but said in a statement to The Associated Press that" + ], + ) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 9e9e4d9930..a482328189 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1783,7 +1783,7 @@ class TFModelTesterMixin: model.compile(optimizer="sgd", run_eagerly=True) model.train_on_batch(test_batch, test_batch_labels) - def _test_xla_generate(self, num_beams, num_return_sequences, max_length): + def _test_xla_generate(self, num_beams, num_return_sequences, max_length, **generate_kwargs): def _generate_and_check_results(model, config, inputs_dict): if "input_ids" in inputs_dict: inputs = inputs_dict["input_ids"] @@ -1801,9 +1801,9 @@ class TFModelTesterMixin: else: raise ValueError("No valid generate input found in inputs_dict") - generated = model.generate(inputs).numpy() + generated = model.generate(inputs, **generate_kwargs).numpy() generate_xla = tf.function(model.generate, jit_compile=True) - generated_xla = generate_xla(inputs).numpy() + generated_xla = generate_xla(inputs, **generate_kwargs).numpy() self.assertListEqual(generated.tolist(), generated_xla.tolist()) for model_class in self.all_generative_model_classes: @@ -1844,6 +1844,19 @@ class TFModelTesterMixin: max_length = 10 self._test_xla_generate(num_beams, num_return_sequences, max_length) + def test_xla_generate_contrastive(self): + """ + Similar to `test_xla_generate_fast`, but for contrastive search -- contrastive search directly manipulates the + model cache and other outputs, and this test ensures that they are in a valid format that is also supported + by XLA. + + Either the model supports XLA generation and passes the inner test, or it raises an appropriate exception + """ + num_beams = 1 + num_return_sequences = 1 + max_length = 10 + self._test_xla_generate(num_beams, num_return_sequences, max_length, penalty_alpha=0.5, top_k=5) + @slow def test_xla_generate_slow(self): """