TF XLA greedy generation (#15786)

* First attempt at TF XLA generation

* Fix comments

* Update XLA greedy generate with direct XLA calls

* Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy

* Handle position_ids correctly

* make xla generate work for non xla case

* force using xla generate

* refactor

* more fixes

* finish cleaning

* finish

* finish

* clean gpt2 tests

* add gpt2 tests

* correct more cases

* up

* finish

* finish

* more fixes

* flake 8 stuff

* final rag fix

* Update src/transformers/models/rag/modeling_tf_rag.py

* finish t5 as well

* finish

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Matt 2022-03-15 13:19:20 +00:00 committed by GitHub
parent e5bc438cc8
commit cd4c5c9060
8 changed files with 370 additions and 93 deletions
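For context, the end-to-end usage this PR enables is compiling `generate` with XLA via `tf.function(..., jit_compile=True)`, which is exactly what the new GPT-2 and T5 tests below exercise. A minimal sketch, assuming the public `gpt2` checkpoint, greedy decoding, and an illustrative prompt and `max_length`:

import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="tf").input_ids

# Compile greedy generation once with XLA, then call the compiled function
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=False, max_length=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))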


@ -260,7 +260,6 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
score_penalties = self._create_score_penalties(input_ids, scores)
scores = tf.math.multiply(scores, score_penalties)


@ -1484,9 +1484,12 @@ class TFGenerationMixin:
batch_size = input_ids.shape[0]
# 3. Prepare other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
if output_attentions is not None:
model_kwargs["output_attentions"] = output_attentions
if output_hidden_states is not None:
model_kwargs["output_hidden_states"] = output_hidden_states
if use_cache is not None:
model_kwargs["use_cache"] = use_cache
requires_attention_mask = "encoder_outputs" not in model_kwargs
@ -1533,7 +1536,6 @@ class TFGenerationMixin:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# 8. run greedy search
return self.greedy_search(
input_ids,
@ -1545,7 +1547,6 @@ class TFGenerationMixin:
return_dict_in_generate=return_dict_in_generate,
**model_kwargs,
)
elif is_sample_gen_mode:
# 8. prepare logits warper
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
@ -1571,15 +1572,13 @@ class TFGenerationMixin:
**model_kwargs,
)
# TODO(Matt, Joao, Patrick) - add more sub-generation methods here
def _prepare_attention_mask_for_generation(
self,
input_ids: tf.Tensor,
pad_token_id: int,
) -> tf.Tensor:
# prepare `attention_mask` if not passed
if (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id):
return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
else:
return tf.ones(input_ids.shape[:2], dtype=tf.int32)
@ -1717,6 +1716,14 @@ class TFGenerationMixin:
return model_kwargs
def _update_model_kwargs_for_xla_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], current_pos: tf.Tensor, max_length: int
) -> Dict[str, Any]:
raise NotImplementedError(
f"{self.__class__} is not compileable with XLA at the moment. You should implement a "
"`_update_model_kwargs_for_xla_generation` in the respective modeling file for XLA-compatible generation."
)
def _get_logits_warper(
self,
top_k: Optional[int] = None,
@ -1773,7 +1780,7 @@ class TFGenerationMixin:
processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if bad_words_ids is not None:
processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1:
if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
return processors
@ -1858,7 +1865,8 @@ class TFGenerationMixin:
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
```"""
# init values
# 1. init greedy_search values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@ -1871,94 +1879,153 @@ class TFGenerationMixin:
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
use_xla = not tf.executing_eagerly()
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# 2. init `attentions`, `hidden_states`, and `scores` tuples
scores = [] if (return_dict_in_generate and output_scores) else None
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# 3. init tensors to use for "xla-compileable" generate function
# define bsz, seq_length
batch_size, seq_length = input_ids.shape
# keep track of which sequences are already finished
unfinished_sequences = tf.ones_like(input_ids[:, 0])
cur_len = input_ids.shape[-1]
# initialize `generated`, `finished_sequences`, and `current_pos`
generated = tf.TensorArray(
element_shape=(batch_size,),
dtype=tf.int32,
dynamic_size=False,
size=max_length,
clear_after_read=False,
)
while cur_len < max_length:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# write prompt to generated
for i in range(seq_length):
generated = generated.write(i, input_ids[:, i])
# forward pass to get next token
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length
# 4. define "xla-compile-able" stop-condition and auto-regressive function
# define condition fn
def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
"""state termination condition fn."""
return ~tf.reduce_all(finished_sequences)
# define body fn
def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
"""state update fn."""
# TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
# forward pass to get next token logits
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
next_token_logits = outputs.logits[:, -1, :]
next_token_logits = outputs.logits[:, -1]
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if not use_xla and return_dict_in_generate:
if output_scores:
scores += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
scores.append(next_token_logits)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
decoder_attentions.append(outputs.attentions)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
cross_attentions.append(outputs.cross_attentions)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
if output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.decoder_hidden_states)
elif output_hidden_states and not self.config.is_encoder_decoder:
decoder_hidden_states.append(outputs.hidden_states)
# pre-process distribution
# TODO(pvp, joao, matt) - all the logits processors need to be adapted
# to be XLA compatible
input_ids = None
if not use_xla:
input_ids = tf.reshape(generated.concat(), (-1, batch_size))
input_ids = tf.transpose(input_ids[: current_pos[0]])
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax
next_tokens = tf.cast(tf.argmax(next_tokens_scores, axis=-1), tf.int32)
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
finished_sequences = finished_sequences | (next_tokens == eos_token_id)
# update generated ids, model inputs, and length for next step
input_ids = tf.concat([input_ids, next_tokens[:, None]], axis=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
cur_len = cur_len + 1
# update `generated` and `current_pos`
generated = generated.write(current_pos[0], next_tokens)
next_tokens = next_tokens[:, None]
current_pos += 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
eos_in_sents = next_tokens == eos_token_id
# if sentence is unfinished and the token to add is eos
is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
unfinished_sequences, tf.cast(eos_in_sents, tf.int32)
# update model_kwargs
if use_xla:
model_kwargs = self._update_model_kwargs_for_xla_generation(
outputs, model_kwargs, current_pos, max_length
)
else:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
# if we don't cache past key values we need the whole input
if model_kwargs.get("past", None) is None:
# let's throw out `past` since we don't want `None` tensors
model_kwargs.pop("past", None)
# unfinished_sequences is set to zero if eos in sentence
unfinished_sequences -= is_sents_unfinished_and_token_to_add_is_eos
next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
next_tokens = tf.transpose(next_tokens[: current_pos[0]])
# stop when each sentence is finished, or if we exceed the maximum length
if tf.math.reduce_max(unfinished_sequences) == 0:
break
return generated, finished_sequences, next_tokens, current_pos, model_kwargs
# 5. run generation
# 1st generation step has to be run before to initialize `past`
generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
generated, finished_sequences, input_ids, current_pos, model_kwargs
)
# 2-to-n generation steps can then be run in autoregressive fashion
# only in case 1st generation step does NOT yield EOS token though
if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
maximum_iterations = max_length - seq_length - 1
generated, _, _, current_pos, _ = tf.while_loop(
greedy_search_cond_fn,
greedy_search_body_fn,
(generated, finished_sequences, next_tokens, current_pos, model_kwargs),
maximum_iterations=maximum_iterations,
)
# 6. prepare outputs
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
if not use_xla:
# cut for backward compatibility
output_ids = output_ids[:, : current_pos[0]]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
# if model is an encoder-decoder, retrieve encoder attention weights
# and hidden states
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
scores = tuple(scores) if scores is not None else None
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
return TFGreedySearchEncoderDecoderOutput(
sequences=input_ids,
sequences=output_ids,
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
@ -1968,13 +2035,13 @@ class TFGenerationMixin:
)
else:
return TFGreedySearchDecoderOnlyOutput(
sequences=input_ids,
sequences=output_ids,
scores=scores,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
return input_ids
return output_ids
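The rewritten greedy_search above swaps the eager Python while loop for a `greedy_search_cond_fn`/`greedy_search_body_fn` pair driven by `tf.while_loop`, with generated tokens written into a fixed-size `tf.TensorArray` so all shapes stay static under XLA. A stripped-down sketch of that control-flow pattern only, assuming a toy state and a fake "model" that simply echoes the position (names and values are illustrative, not the real `greedy_search` signature):

import tensorflow as tf

max_length = 8

# Toy stand-in for greedy_search_cond_fn: stop once the array is full
def cond_fn(generated, current_pos):
    return current_pos < max_length

# Toy stand-in for greedy_search_body_fn: a real body would run the model, apply the
# logits processors and take an argmax; here the "next token" is just the position
def body_fn(generated, current_pos):
    next_token = current_pos
    generated = generated.write(current_pos, next_token)
    return generated, current_pos + 1

generated = tf.TensorArray(dtype=tf.int32, size=max_length, dynamic_size=False, clear_after_read=False)
generated = generated.write(0, tf.constant(0, dtype=tf.int32))  # the "prompt"

generated, _ = tf.while_loop(
    cond_fn, body_fn, (generated, tf.constant(1, dtype=tf.int32)), maximum_iterations=max_length - 1
)
print(generated.stack().numpy())  # [0 1 2 3 4 5 6 7]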
def sample(
self,


@ -18,7 +18,9 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from ...activations_tf import get_tf_activation
from ...file_utils import (
@ -851,7 +853,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
def set_output_embeddings(self, value):
self.set_input_embeddings(value)
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
# tests will need to be fixed after the change
@ -859,7 +861,81 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
# TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
# for a future PR to not change too many things for now.
# All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
position_ids = None
attention_mask = None
if use_xla:
attention_mask = kwargs.get("attention_mask", None)
if past is not None and attention_mask is not None:
position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
elif attention_mask is not None:
position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
return {
"input_ids": inputs,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past,
"use_cache": use_cache,
}
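To make the `position_ids` arithmetic above concrete, here is a toy example with a hypothetical left-padded prompt (values are made up, not from the PR): the first pass uses an exclusive cumulative sum so real tokens get positions starting at 0 regardless of padding, and cached passes place the single new token at the number of attended tokens minus one.

import tensorflow as tf

# Hypothetical left-padded prompt mask: two padding positions, then three real tokens
attention_mask = tf.constant([[0, 0, 1, 1, 1]], dtype=tf.int32)

# First step (no `past`): positions count only the attended tokens
position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
print(position_ids.numpy())  # [[0 0 0 1 2]]

# Later steps (with `past`): the new token sits at position sum(mask) - 1
attention_mask = tf.concat([attention_mask, tf.ones((1, 1), dtype=tf.int32)], axis=1)
print((tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1).numpy())  # [[3]]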
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
# also the `attention_mask` is currently used in a somewhat hacky way to
# correctly influence the `past_key_values` - not sure if this is the way to go
# Let's keep that for a future PR.
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
attention_mask = model_kwargs.pop("attention_mask")
batch_size = attention_mask.shape[0]
if not is_past_initialized:
# past[0].shape[3] is seq_length of prompt
num_padding_values = max_length - past[0].shape[3] - 1
padding_values = np.zeros((5, 2), dtype=np.int32)
padding_values[3, 1] = num_padding_values
padding_values = tf.constant(padding_values)
new_past = list(past)
for i in range(len(past)):
new_past[i] = tf.pad(past[i], padding_values)
# Zeros for the currently-unfilled locations in the past tensor, ones for the actual input_ids
attention_mask = tf.concat(
[
attention_mask,
tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
tf.ones((batch_size, 1), dtype=attention_mask.dtype),
],
axis=1,
)
else:
new_past = [None for _ in range(len(past))]
slice_start_base = tf.constant([0, 0, 0, 1, 0])
attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
# correct 5 here
new_past_index = current_pos - 1
for i in range(len(past)):
update_slice = past[i][:, :, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past[i] = dynamic_update_slice(
past[i][:, :, :, :-1], update_slice, slice_start_base * new_past_index
)
update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)
# set `attention_mask` and `past`
model_kwargs["attention_mask"] = attention_mask
model_kwargs["past"] = tuple(new_past)
return model_kwargs
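The cached branch above keeps `past` at a fixed `max_length` and writes each new key/value slice in place with `dynamic_update_slice`, so tensor shapes never change inside the XLA-compiled loop. A toy sketch of that in-place write, reusing the same `[0, 1] * index` start-index trick as the hunk; `write_slot`, the buffer contents, and the position are made up for illustration:

import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

@tf.function(jit_compile=True)
def write_slot(buffer, update, position):
    # Write `update` into `buffer` at column `position`; the buffer keeps its static shape
    start_indices = tf.constant([0, 1], dtype=tf.int32) * position
    return dynamic_update_slice(buffer, update, start_indices)

# A fixed-size buffer of length 6 with only the first two slots filled (toy values)
buffer = tf.constant([[1, 2, 0, 0, 0, 0]], dtype=tf.int32)
buffer = write_slot(buffer, tf.constant([[3]], dtype=tf.int32), tf.constant(2, dtype=tf.int32))
print(buffer.numpy())  # [[1 2 3 0 0 0]]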
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(


@ -1309,9 +1309,13 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
min_length=min_length,
eos_token_id=eos_token_id,
)
# TODO(Patrick) clean-up once generate is fully cleaned up
model_kwargs["attention_mask"] = context_attention_mask
# TODO(Patrick) remove once generate is fully cleaned up
if model_kwargs.get("encoder_attentions", None) is None:
model_kwargs.pop("encoder_attentions", None)
if model_kwargs.get("encoder_hidden_states", None) is None:
model_kwargs.pop("encoder_hidden_states", None)
model_kwargs.pop("output_hidden_states", None)
model_kwargs.pop("output_attentions", None)
model_kwargs.pop("output_scores", None)


@ -21,7 +21,9 @@ import math
import warnings
from typing import Tuple
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
from ...activations_tf import get_tf_activation
from ...file_utils import (
@ -1545,6 +1547,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
input_ids,
past=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
use_cache=None,
@ -1562,11 +1565,76 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"use_cache": use_cache,
}
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
# TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
# quite some duplicated code patterns it seems
# also the `attention_mask` is currently used in a somewhat hacky way to
# correctly influence the `past_key_values` - not sure if this is the way to go
# Let's keep that for a future PR.
past = outputs.past_key_values
is_past_initialized = model_kwargs.pop("past", None) is not None
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
batch_size = past[0][0].shape[0]
if not is_past_initialized:
# past[0][0].shape[2] is seq_length of prompt
num_padding_values = max_length - past[0][0].shape[2] - 1
padding_values = np.zeros((4, 2), dtype=np.int32)
padding_values[2, 1] = num_padding_values
padding_values = tf.constant(padding_values)
new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
new_past_layer[i] = tf.pad(past_layer[i], padding_values)
new_past += (tuple(new_past_layer),)
# one 1 for decoder_start_token_id, zeros for the currently-unfilled locations in the past tensor, and one 1 for the actual input_ids
decoder_attention_mask = tf.concat(
[
tf.ones((batch_size, 1), dtype=tf.int32),
tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
tf.ones((batch_size, 1), dtype=tf.int32),
],
axis=1,
)
else:
slice_start_base = tf.constant([0, 0, 1, 0])
decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
# correct 5 here
new_past_index = current_pos - 1
new_past = ()
for past_layer in past:
new_past_layer = list(past_layer)
for i in range(len(new_past_layer[:2])):
update_slice = past_layer[i][:, :, -1:]
# Write the last slice to the first open location in the padded past array
# and then truncate the last slice off the array
new_past_layer[i] = dynamic_update_slice(
past_layer[i][:, :, :-1], update_slice, slice_start_base * new_past_index
)
new_past += (tuple(new_past_layer),)
update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
decoder_attention_mask = dynamic_update_slice(
decoder_attention_mask, decoder_attention_mask_update_slice, update_start
)
# set `attention_mask` and `past`
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["past"] = new_past
return model_kwargs
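The encoder-decoder variant above pre-allocates `decoder_attention_mask` in the same spirit: a 1 for the decoder start token, zeros for every slot that has not been generated yet, and a trailing 1 for the token processed in the current step. A tiny worked example with assumed sizes (batch of 2, max_length of 6, a one-token decoder prompt), mirroring the `tf.concat` in the hunk:

import tensorflow as tf

batch_size, max_length = 2, 6
prompt_length = 1  # the decoder prompt is just the decoder_start_token_id
num_padding_values = max_length - prompt_length - 1

decoder_attention_mask = tf.concat(
    [
        tf.ones((batch_size, 1), dtype=tf.int32),                    # decoder_start_token_id
        tf.zeros((batch_size, num_padding_values), dtype=tf.int32),  # not-yet-generated slots
        tf.ones((batch_size, 1), dtype=tf.int32),                    # token processed this step
    ],
    axis=1,
)
print(decoder_attention_mask.numpy())  # [[1 0 0 0 0 1], [1 0 0 0 0 1]]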
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
return self._shift_right(labels)


@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
else:
model.gradient_checkpointing_disable()
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# The dog
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [
464,
3290,
373,
1043,
287,
257,
2214,
1474,
262,
16246,
286,
2688,
290,
2688,
27262,
13,
198,
198,
464,
3290,
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,
]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
if verify_outputs:
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)


@ -294,6 +294,21 @@ class TFGPT2ModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
model = TFGPT2LMHeadModel(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def create_and_check_gpt2_double_head(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
):
@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
def test_gpt2_xla_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
def test_gpt2_double_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
@slow
def test_lm_generate_gpt2_xla(self):
"""This test gives the exact same results as the non-xla test above"""
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
# fmt: off
expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
# fmt: on
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = xla_generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)


@ -227,6 +227,23 @@ class TFT5ModelTester:
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_t5_xla_generate(self, config, input_ids, *args):
config.eos_token_id = None
config.max_length = 10
config.do_sample = False
config.num_beams = 1
model = TFT5ForConditionalGeneration(config=config)
# make sure there are no pad tokens in prompt
input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5)
generated = model.generate(input_ids)
generate_xla = tf.function(model.generate, jit_compile=True)
generated_xla = generate_xla(input_ids)
self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
@ -280,6 +297,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
def test_t5_model_xla_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -454,6 +475,27 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
@slow
def test_greedy_xla_generate_simple(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
sentence = "Translate English to German: Today is a beautiful day."
input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids
xla_generate = tf.function(model.generate, jit_compile=True)
output_ids = model.generate(input_ids)
output_ids_xla = xla_generate(input_ids)
output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)
expected_output_string = ["Heute ist ein schöner Tag."]
self.assertListEqual(expected_output_string, output_strings)
self.assertListEqual(expected_output_string, output_strings_xla)
@slow
def test_greedy_generate(self):
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")