From cd4c5c90605b2e23879fcca484f7079b0fc0c361 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 15 Mar 2022 13:19:20 +0000 Subject: [PATCH] TF XLA greedy generation (#15786) * First attempt at TF XLA generation * Fix comments * Update XLA greedy generate with direct XLA calls * Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy * Handle position_ids correctly * make xla generate work for non xla case * force using xla generate * refactor * more fixes * finish cleaning * finish * finish * clean gpt2 tests * add gpt2 tests * correct more cases * up * finish * finish * more fixes * flake 8 stuff * final rag fix * Update src/transformers/models/rag/modeling_tf_rag.py * finish t5 as well * finish * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen --- .../generation_tf_logits_process.py | 1 - src/transformers/generation_tf_utils.py | 199 ++++++++++++------ .../models/gpt2/modeling_tf_gpt2.py | 80 ++++++- .../models/rag/modeling_tf_rag.py | 8 +- src/transformers/models/t5/modeling_tf_t5.py | 68 ++++++ tests/gpt2/test_modeling_gpt2.py | 31 +-- tests/gpt2/test_modeling_tf_gpt2.py | 34 +++ tests/t5/test_modeling_tf_t5.py | 42 ++++ 8 files changed, 370 insertions(+), 93 deletions(-) diff --git a/src/transformers/generation_tf_logits_process.py b/src/transformers/generation_tf_logits_process.py index 098d76ef27..271957c1ac 100644 --- a/src/transformers/generation_tf_logits_process.py +++ b/src/transformers/generation_tf_logits_process.py @@ -260,7 +260,6 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor): return tf.convert_to_tensor(token_penalties, dtype=tf.float32) def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor: - score_penalties = self._create_score_penalties(input_ids, scores) scores = tf.math.multiply(scores, score_penalties) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 247467702e..2a5234004a 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -1484,9 +1484,12 @@ class TFGenerationMixin: batch_size = input_ids.shape[0] # 3. Prepare other model kwargs - model_kwargs["output_attentions"] = output_attentions - model_kwargs["output_hidden_states"] = output_hidden_states - model_kwargs["use_cache"] = use_cache + if output_attentions is not None: + model_kwargs["output_attentions"] = output_attentions + if output_hidden_states is not None: + model_kwargs["output_hidden_states"] = output_hidden_states + if use_cache is not None: + model_kwargs["use_cache"] = use_cache requires_attention_mask = "encoder_outputs" not in model_kwargs @@ -1533,7 +1536,6 @@ class TFGenerationMixin: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - # 8. run greedy search return self.greedy_search( input_ids, @@ -1545,7 +1547,6 @@ class TFGenerationMixin: return_dict_in_generate=return_dict_in_generate, **model_kwargs, ) - elif is_sample_gen_mode: # 8. 
prepare logits warper logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) @@ -1571,15 +1572,13 @@ class TFGenerationMixin: **model_kwargs, ) - # TODO(Matt, Joao, Patrick) - add more sub-generation methods here - def _prepare_attention_mask_for_generation( self, input_ids: tf.Tensor, pad_token_id: int, ) -> tf.Tensor: # prepare `attention_mask` if not passed - if (pad_token_id is not None) and (pad_token_id in input_ids.numpy()): + if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id): return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32) else: return tf.ones(input_ids.shape[:2], dtype=tf.int32) @@ -1717,6 +1716,14 @@ class TFGenerationMixin: return model_kwargs + def _update_model_kwargs_for_xla_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], current_pos: tf.Tensor, max_length: int + ) -> Dict[str, Any]: + raise NotImplementedError( + f"{self.__class__} is not compileable with XLA at the moment. You should implement a " + "`_update_model_kwargs_for_xla_generation` in the respective modeling file for XLA-compatible generation." + ) + def _get_logits_warper( self, top_k: Optional[int] = None, @@ -1773,7 +1780,7 @@ class TFGenerationMixin: processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) if bad_words_ids is not None: processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > -1: + if min_length is not None and eos_token_id is not None and min_length > 0: processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id)) return processors @@ -1858,7 +1865,8 @@ class TFGenerationMixin: >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) ```""" - # init values + + # 1. init greedy_search values logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id @@ -1871,94 +1879,153 @@ class TFGenerationMixin: return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) + use_xla = not tf.executing_eagerly() - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + # 2. init `attentions`, `hidden_states`, and `scores` tuples + scores = [] if (return_dict_in_generate and output_scores) else None + decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None + cross_attentions = [] if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) + # 3. 
init tensors to use for "xla-compileable" generate function + # define bsz, seq_length + batch_size, seq_length = input_ids.shape - # keep track of which sequences are already finished - unfinished_sequences = tf.ones_like(input_ids[:, 0]) - cur_len = input_ids.shape[-1] + # initialize `generated`, `finished_sequences`, and `current_pos` + generated = tf.TensorArray( + element_shape=(batch_size,), + dtype=tf.int32, + dynamic_size=False, + size=max_length, + clear_after_read=False, + ) - while cur_len < max_length: - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # write prompt to generated + for i in range(seq_length): + generated = generated.write(i, input_ids[:, i]) - # forward pass to get next token + finished_sequences = tf.zeros((batch_size,), dtype=tf.bool) + current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length + + # 4. define "xla-compile-able" stop-condition and auto-regressive function + # define condition fn + def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): + """state termination condition fn.""" + return ~tf.reduce_all(finished_sequences) + + # define condition fn + def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): + """state update fn.""" + # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`. + model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs) + # forward pass to get next token logits outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) - - next_token_logits = outputs.logits[:, -1, :] + next_token_logits = outputs.logits[:, -1] # Store scores, attentions and hidden_states when required - if return_dict_in_generate: + if not use_xla and return_dict_in_generate: if output_scores: - scores += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) + scores.append(next_token_logits) + if output_attentions and self.config.is_encoder_decoder: + decoder_attentions.append(outputs.decoder_attentions) + elif output_attentions and not self.config.is_encoder_decoder: + decoder_attentions.append(outputs.attentions) if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + cross_attentions.append(outputs.cross_attentions) - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) + if output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.decoder_hidden_states) + elif output_hidden_states and self.config.is_encoder_decoder: + decoder_hidden_states.append(outputs.hidden_states) # pre-process distribution + # TODO(pvp, joao, matt) - all the logits processors need to be adapted + # to be XLA compatible + input_ids = None + if not use_xla: + input_ids = tf.reshape(generated.concat(), (-1, batch_size)) + input_ids = tf.transpose(input_ids[: current_pos[0]]) next_tokens_scores = logits_processor(input_ids, next_token_logits) # argmax - next_tokens = tf.cast(tf.argmax(next_tokens_scores, axis=-1), tf.int32) + next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32) - # finished sentences should have their next token 
be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32) + next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq) + finished_sequences = finished_sequences | (next_tokens == eos_token_id) - # update generated ids, model inputs, and length for next step - input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - cur_len = cur_len + 1 + # update `generated` and `current_pos` + generated = generated.write(current_pos[0], next_tokens) + next_tokens = next_tokens[:, None] + current_pos += 1 - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id is not None: - eos_in_sents = next_tokens == eos_token_id - # if sentence is unfinished and the token to add is eos - is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply( - unfinished_sequences, tf.cast(eos_in_sents, tf.int32) + # update model_kwargs + if use_xla: + model_kwargs = self._update_model_kwargs_for_xla_generation( + outputs, model_kwargs, current_pos, max_length ) + else: + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + # if we don't cache past key values we need the whole input + if model_kwargs.get("past", None) is None: + # let's throw out `past` since we don't want `None` tensors + model_kwargs.pop("past", None) - # unfinished_sequences is set to zero if eos in sentence - unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos + next_tokens = tf.reshape(generated.concat(), (-1, batch_size)) + next_tokens = tf.transpose(next_tokens[: current_pos[0]]) - # stop when each sentence is finished, or if we exceed the maximum length - if tf.math.reduce_max(unfinished_sequences) == 0: - break + return generated, finished_sequences, next_tokens, current_pos, model_kwargs + + # 5. run generation + # 1st generation step has to be run before to initialize `past` + generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn( + generated, finished_sequences, input_ids, current_pos, model_kwargs + ) + + # 2-to-n generation steps can then be run in autoregressive fashion + # only in case 1st generation step does NOT yield EOS token though + if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs): + maximum_iterations = max_length - seq_length - 1 + generated, _, _, current_pos, _ = tf.while_loop( + greedy_search_cond_fn, + greedy_search_body_fn, + (generated, finished_sequences, next_tokens, current_pos, model_kwargs), + maximum_iterations=maximum_iterations, + ) + + # 6. 
prepare outputs + output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size))) + + if not use_xla: + # cut for backward compatibility + output_ids = output_ids[:, : current_pos[0]] if return_dict_in_generate: if self.config.is_encoder_decoder: + # if model is an encoder-decoder, retrieve encoder attention weights + # and hidden states + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + scores = tuple(scores) if scores is not None else None + decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None + cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None + decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None + return TFGreedySearchEncoderDecoderOutput( - sequences=input_ids, + sequences=output_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -1968,13 +2035,13 @@ class TFGenerationMixin: ) else: return TFGreedySearchDecoderOnlyOutput( - sequences=input_ids, + sequences=output_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return input_ids + return output_ids def sample( self, diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 98f78e16da..f17504cbd1 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -18,7 +18,9 @@ from dataclasses import dataclass from typing import List, Optional, Tuple +import numpy as np import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from ...activations_tf import get_tf_activation from ...file_utils import ( @@ -851,7 +853,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): def set_output_embeddings(self, value): self.set_input_embeddings(value) - def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs): + def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs): # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2 # tests will need to be fixed after the change @@ -859,7 +861,81 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss): if past: inputs = tf.expand_dims(inputs[:, -1], -1) - return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache} + # TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left + # for a future PR to not change too many things for now. 
+ # All statements in this if case apply for both xla and non-xla (as they already do in PyTorch) + position_ids = None + attention_mask = None + if use_xla: + attention_mask = kwargs.get("attention_mask", None) + if past is not None and attention_mask is not None: + position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1 + elif attention_mask is not None: + position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True) + + return { + "input_ids": inputs, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past, + "use_cache": use_cache, + } + + def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length): + # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored + # quite some duplicated code patterns it seems + # also the `attention_mask` is currently used in a somewhat hacky to + # correctly influence the `past_key_values` - not sure if this is the way to go + # Let's keep that for a future PR. + past = outputs.past_key_values + is_past_initialized = model_kwargs.pop("past", None) is not None + attention_mask = model_kwargs.pop("attention_mask") + batch_size = attention_mask.shape[0] + + if not is_past_initialized: + # past[0].shape[3] is seq_length of prompt + num_padding_values = max_length - past[0].shape[3] - 1 + + padding_values = np.zeros((5, 2), dtype=np.int32) + padding_values[3, 1] = num_padding_values + padding_values = tf.constant(padding_values) + + new_past = list(past) + for i in range(len(past)): + new_past[i] = tf.pad(past[i], padding_values) + + # Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids + attention_mask = tf.concat( + [ + attention_mask, + tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype), + tf.ones((batch_size, 1), dtype=attention_mask.dtype), + ], + axis=1, + ) + else: + new_past = [None for _ in range(len(past))] + slice_start_base = tf.constant([0, 0, 0, 1, 0]) + attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype) + # correct 5 here + new_past_index = current_pos - 1 + + for i in range(len(past)): + update_slice = past[i][:, :, :, -1:] + # Write the last slice to the first open location in the padded past array + # and then truncate the last slice off the array + new_past[i] = dynamic_update_slice( + past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index + ) + + update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index + attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start) + + # set `attention_mask` and `past` + model_kwargs["attention_mask"] = attention_mask + model_kwargs["past"] = tuple(new_past) + + return model_kwargs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index 53a2186425..09b7e5991c 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -1309,9 +1309,13 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss min_length=min_length, eos_token_id=eos_token_id, ) - # TODO(Patrick) clean-up once generate is fully cleaned up model_kwargs["attention_mask"] = context_attention_mask - # TODO(Patrick) remove once generate is fully cleaned up + + if model_kwargs.get("encoder_attentions", None) is None: + 
model_kwargs.pop("encoder_attentions", None) + if model_kwargs.get("encoder_hidden_states", None) is None: + model_kwargs.pop("encoder_hidden_states", None) + model_kwargs.pop("output_hidden_states", None) model_kwargs.pop("output_attentions", None) model_kwargs.pop("output_scores", None) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 019efed707..9928f38413 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -21,7 +21,9 @@ import math import warnings from typing import Tuple +import numpy as np import tensorflow as tf +from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice from ...activations_tf import get_tf_activation from ...file_utils import ( @@ -1545,6 +1547,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling input_ids, past=None, attention_mask=None, + decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, use_cache=None, @@ -1562,11 +1565,76 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling "past_key_values": past, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "use_cache": use_cache, } + def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length): + # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored + # quite some duplicated code patterns it seems + # also the `attention_mask` is currently used in a somewhat hacky to + # correctly influence the `past_key_values` - not sure if this is the way to go + # Let's keep that for a future PR. 
+ past = outputs.past_key_values + is_past_initialized = model_kwargs.pop("past", None) is not None + decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None) + batch_size = past[0][0].shape[0] + + if not is_past_initialized: + # past[0].shape[3] is seq_length of prompt + num_padding_values = max_length - past[0][0].shape[2] - 1 + + padding_values = np.zeros((4, 2), dtype=np.int32) + padding_values[2, 1] = num_padding_values + padding_values = tf.constant(padding_values) + + new_past = () + for past_layer in past: + new_past_layer = list(past_layer) + for i in range(len(new_past_layer[:2])): + new_past_layer[i] = tf.pad(past_layer[i], padding_values) + new_past += (tuple(new_past_layer),) + + # 1 one for decoder_start_token_id, Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids + decoder_attention_mask = tf.concat( + [ + tf.ones((batch_size, 1), dtype=tf.int32), + tf.zeros((batch_size, num_padding_values), dtype=tf.int32), + tf.ones((batch_size, 1), dtype=tf.int32), + ], + axis=1, + ) + else: + slice_start_base = tf.constant([0, 0, 1, 0]) + decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype) + # correct 5 here + new_past_index = current_pos - 1 + + new_past = () + for past_layer in past: + new_past_layer = list(past_layer) + for i in range(len(new_past_layer[:2])): + update_slice = past_layer[i][:, :, -1:] + # Write the last slice to the first open location in the padded past array + # and then truncate the last slice off the array + new_past_layer[i] = dynamic_update_slice( + past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index + ) + new_past += (tuple(new_past_layer),) + + update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index + decoder_attention_mask = dynamic_update_slice( + decoder_attention_mask, decoder_attention_mask_update_slice, update_start + ) + + # set `attention_mask` and `past` + model_kwargs["decoder_attention_mask"] = decoder_attention_mask + model_kwargs["past"] = new_past + + return model_kwargs + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): return self._shift_right(labels) diff --git a/tests/gpt2/test_modeling_gpt2.py b/tests/gpt2/test_modeling_gpt2.py index e80c924310..cea36400b2 100644 --- a/tests/gpt2/test_modeling_gpt2.py +++ b/tests/gpt2/test_modeling_gpt2.py @@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): else: model.gradient_checkpointing_disable() model.to(torch_device) - input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog + + # The dog + input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) + + # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + # fmt: off expected_output_ids = [ - 464, - 3290, - 373, - 1043, - 287, - 257, - 2214, - 1474, - 262, - 16246, - 286, - 2688, - 290, - 2688, - 27262, - 13, - 198, - 198, - 464, - 3290, - ] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + 464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290, + ] + # fmt: on output_ids = model.generate(input_ids, do_sample=False) if verify_outputs: self.assertListEqual(output_ids[0].tolist(), expected_output_ids) diff --git a/tests/gpt2/test_modeling_tf_gpt2.py b/tests/gpt2/test_modeling_tf_gpt2.py index 4bc1c876e0..6ff35b6be3 100644 --- a/tests/gpt2/test_modeling_tf_gpt2.py +++ 
b/tests/gpt2/test_modeling_tf_gpt2.py @@ -294,6 +294,21 @@ class TFGPT2ModelTester: result = model(inputs) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_gpt2_xla_generate(self, config, input_ids, *args): + config.eos_token_id = None + config.max_length = 10 + model = TFGPT2LMHeadModel(config=config) + + # make sure there are no pad tokens in prompt + input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1) + + generated = model.generate(input_ids) + + generate_xla = tf.function(model.generate, jit_compile=True) + generated_xla = generate_xla(input_ids) + + self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist()) + def create_and_check_gpt2_double_head( self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args ): @@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs) + def test_gpt2_xla_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs) + def test_gpt2_double_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs) @@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase): # fmt: on output_ids = model.generate(input_ids, do_sample=False) self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) + + @slow + def test_lm_generate_gpt2_xla(self): + """This test gives the exact same results as the non-xla test above""" + model = TFGPT2LMHeadModel.from_pretrained("gpt2") + input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog + + # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog + # fmt: off + expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290] + # fmt: on + xla_generate = tf.function(model.generate, jit_compile=True) + + output_ids = xla_generate(input_ids, do_sample=False) + self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) diff --git a/tests/t5/test_modeling_tf_t5.py b/tests/t5/test_modeling_tf_t5.py index 5abf66f4c2..f7397cc615 100644 --- a/tests/t5/test_modeling_tf_t5.py +++ b/tests/t5/test_modeling_tf_t5.py @@ -227,6 +227,23 @@ class TFT5ModelTester: # test that outputs are equal for slice tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3) + def create_and_check_t5_xla_generate(self, config, input_ids, *args): + config.eos_token_id = None + config.max_length = 10 + config.do_sample = False + config.num_beams = 1 + model = TFT5ForConditionalGeneration(config=config) + + # make sure there are no pad tokens in prompt + input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5) + + generated = model.generate(input_ids) + + generate_xla = tf.function(model.generate, jit_compile=True) + generated_xla = generate_xla(input_ids) + + self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist()) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, token_labels) = config_and_inputs @@ -280,6 
+297,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs) + def test_t5_model_xla_generate(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs) + def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -454,6 +475,27 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase): @require_sentencepiece @require_tokenizers class TFT5GenerationIntegrationTests(unittest.TestCase): + @slow + def test_greedy_xla_generate_simple(self): + model = TFT5ForConditionalGeneration.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + sentence = "Translate English to German: Today is a beautiful day." + input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids + + xla_generate = tf.function(model.generate, jit_compile=True) + + output_ids = model.generate(input_ids) + output_ids_xla = xla_generate(input_ids) + + output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True) + + expected_output_string = ["Heute ist ein schöner Tag."] + + self.assertListEqual(expected_output_string, output_strings) + self.assertListEqual(expected_output_string, output_strings_xla) + @slow def test_greedy_generate(self): model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
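
For readers who want to see the core pattern the new `greedy_search` relies on in isolation: the sequence buffer is pre-allocated as a fixed-size `tf.TensorArray` and the decoding loop is expressed with `tf.while_loop`, so every tensor keeps a static shape and the whole loop can be wrapped in `tf.function(..., jit_compile=True)`. The sketch below is a minimal, self-contained illustration of that pattern only, not the patch's implementation: the model call is replaced by a hypothetical `dummy_logits_fn`, and the function names, vocabulary size, and maximum length are invented for the example.

import tensorflow as tf

VOCAB_SIZE = 16   # toy vocabulary size, an assumption for the example
MAX_LENGTH = 8    # plays the role of generate()'s `max_length`


def dummy_logits_fn(last_tokens):
    # Hypothetical stand-in for the model forward pass: returns fake logits
    # that simply favour `last_token + 1`, so the sketch runs without weights.
    return tf.one_hot((last_tokens + 1) % VOCAB_SIZE, VOCAB_SIZE)


@tf.function(jit_compile=True)
def greedy_decode(input_ids):
    batch_size, prompt_length = input_ids.shape

    # Fixed-size buffer for the whole sequence: XLA needs static shapes, so the
    # output is pre-allocated instead of being grown with tf.concat every step.
    generated = tf.TensorArray(
        dtype=tf.int32,
        size=MAX_LENGTH,
        dynamic_size=False,
        clear_after_read=False,
        element_shape=(batch_size,),
    )
    for i in range(prompt_length):
        generated = generated.write(i, input_ids[:, i])

    def cond_fn(generated, current_pos, last_tokens):
        return current_pos < MAX_LENGTH

    def body_fn(generated, current_pos, last_tokens):
        # Greedy step: pick the argmax token and write it into the next slot.
        next_tokens = tf.argmax(dummy_logits_fn(last_tokens), axis=-1, output_type=tf.int32)
        generated = generated.write(current_pos, next_tokens)
        return generated, current_pos + 1, next_tokens

    generated, _, _ = tf.while_loop(
        cond_fn,
        body_fn,
        (generated, tf.constant(prompt_length, dtype=tf.int32), input_ids[:, -1]),
        maximum_iterations=MAX_LENGTH - prompt_length,
    )
    # The TensorArray stacks to (length, batch); transpose back to (batch, length).
    return tf.transpose(generated.stack())


print(greedy_decode(tf.constant([[3, 5]], dtype=tf.int32)))

The important property is that `generated`, `current_pos`, and the last-token tensor never change shape between iterations, which is what XLA requires; growing the sequence with `tf.concat` inside the loop, as the previous eager-only implementation did, would not compile.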
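
The per-model `_update_model_kwargs_for_xla_generation` hooks added for GPT-2 and T5 apply the same idea to the cache: the past key/value tensors and the attention mask are padded up to `max_length` once, and each new step is written into the next free slot with `dynamic_update_slice` rather than concatenated. Below is a reduced, hypothetical example of that write step with toy shapes; `write_step` and the dimensions are invented for illustration and do not appear in the patch.

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

# Toy dimensions, chosen only for the example.
batch_size, num_heads, max_length, head_dim = 2, 4, 8, 16
prompt_length = 3

# Cache and mask are allocated at the full `max_length` up front, as the new
# hooks do after padding `past` on the first generation step.
past_key = tf.zeros((batch_size, num_heads, max_length, head_dim))
attention_mask = tf.concat(
    [
        tf.ones((batch_size, prompt_length), dtype=tf.int32),
        tf.zeros((batch_size, max_length - prompt_length), dtype=tf.int32),
    ],
    axis=1,
)


@tf.function(jit_compile=True)
def write_step(past_key, attention_mask, new_key, position):
    # Overwrite a single time step in place instead of concatenating, so the
    # cache and mask keep their static shapes across decoding steps.
    key_start = tf.constant([0, 0, 1, 0], dtype=tf.int32) * position
    past_key = dynamic_update_slice(past_key, new_key, key_start)

    mask_update = tf.ones((tf.shape(attention_mask)[0], 1), dtype=attention_mask.dtype)
    mask_start = tf.constant([0, 1], dtype=tf.int32) * position
    attention_mask = dynamic_update_slice(attention_mask, mask_update, mask_start)
    return past_key, attention_mask


new_key = tf.random.normal((batch_size, num_heads, 1, head_dim))
past_key, attention_mask = write_step(past_key, attention_mask, new_key, tf.constant(prompt_length))
print(attention_mask.numpy())  # the slot at `prompt_length` is now marked as filled

Keeping the cache at a fixed `max_length` and advancing a position index is what lets the same compiled graph be reused for every decoding step after the first.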