TF XLA greedy generation (#15786)
* First attempt at TF XLA generation * Fix comments * Update XLA greedy generate with direct XLA calls * Support attention mask, prepare_inputs_for_generation no longer hardcoded for greedy * Handle position_ids correctly * make xla generate work for non xla case * force using xla generate * refactor * more fixes * finish cleaning * finish * finish * clean gpt2 tests * add gpt2 tests * correct more cases * up * finish * finish * more fixes * flake 8 stuff * final rag fix * Update src/transformers/models/rag/modeling_tf_rag.py * finish t5 as well * finish * Update src/transformers/generation_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
e5bc438cc8
commit
cd4c5c9060
|
@ -260,7 +260,6 @@ class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
|
|||
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
|
||||
|
||||
score_penalties = self._create_score_penalties(input_ids, scores)
|
||||
|
||||
scores = tf.math.multiply(scores, score_penalties)
|
||||
|
|
|
@ -1484,9 +1484,12 @@ class TFGenerationMixin:
|
|||
batch_size = input_ids.shape[0]
|
||||
|
||||
# 3. Prepare other model kwargs
|
||||
model_kwargs["output_attentions"] = output_attentions
|
||||
model_kwargs["output_hidden_states"] = output_hidden_states
|
||||
model_kwargs["use_cache"] = use_cache
|
||||
if output_attentions is not None:
|
||||
model_kwargs["output_attentions"] = output_attentions
|
||||
if output_hidden_states is not None:
|
||||
model_kwargs["output_hidden_states"] = output_hidden_states
|
||||
if use_cache is not None:
|
||||
model_kwargs["use_cache"] = use_cache
|
||||
|
||||
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||
|
||||
|
@ -1533,7 +1536,6 @@ class TFGenerationMixin:
|
|||
raise ValueError(
|
||||
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
|
||||
)
|
||||
|
||||
# 8. run greedy search
|
||||
return self.greedy_search(
|
||||
input_ids,
|
||||
|
@ -1545,7 +1547,6 @@ class TFGenerationMixin:
|
|||
return_dict_in_generate=return_dict_in_generate,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
elif is_sample_gen_mode:
|
||||
# 8. prepare logits warper
|
||||
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
|
@ -1571,15 +1572,13 @@ class TFGenerationMixin:
|
|||
**model_kwargs,
|
||||
)
|
||||
|
||||
# TODO(Matt, Joao, Patrick) - add more sub-generation methods here
|
||||
|
||||
def _prepare_attention_mask_for_generation(
    self,
    input_ids: tf.Tensor,
    pad_token_id: int,
) -> tf.Tensor:
    """Build a default `attention_mask` for `input_ids` when none was passed.

    Args:
        input_ids: token ids; only the first two dimensions are used for the
            all-ones fallback.
        pad_token_id: id treated as padding, or `None` if there is none.

    Returns:
        An int32 tensor that is 0 where `input_ids == pad_token_id` and 1
        elsewhere when at least one pad token is present; otherwise an
        all-ones tensor of shape `input_ids.shape[:2]`.
    """
    # The membership test is expressed with tensor ops only. The previous form,
    # `pad_token_id in input_ids.numpy()`, needs a concrete eager tensor and so
    # cannot be used when the function is traced/compiled.
    if (pad_token_id is not None) and tf.math.reduce_any(input_ids == pad_token_id):
        return tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
    else:
        return tf.ones(input_ids.shape[:2], dtype=tf.int32)
|
||||
|
@ -1717,6 +1716,14 @@ class TFGenerationMixin:
|
|||
|
||||
return model_kwargs
|
||||
|
||||
def _update_model_kwargs_for_xla_generation(
|
||||
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], current_pos: tf.Tensor, max_length: int
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is not compileable with XLA at the moment. You should implement a "
|
||||
"`_update_model_kwargs_for_xla_generation` in the respective modeling file for XLA-compatible generation."
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
self,
|
||||
top_k: Optional[int] = None,
|
||||
|
@ -1773,7 +1780,7 @@ class TFGenerationMixin:
|
|||
processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
|
||||
if bad_words_ids is not None:
|
||||
processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
|
||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||
if min_length is not None and eos_token_id is not None and min_length > 0:
|
||||
processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
|
||||
|
||||
return processors
|
||||
|
@ -1858,7 +1865,8 @@ class TFGenerationMixin:
|
|||
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
```"""
|
||||
# init values
|
||||
|
||||
# 1. init greedy_search values
|
||||
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
|
||||
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
|
@ -1871,94 +1879,153 @@ class TFGenerationMixin:
|
|||
return_dict_in_generate = (
|
||||
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
|
||||
)
|
||||
use_xla = not tf.executing_eagerly()
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||
# 2. init `attentions`, `hidden_states`, and `scores` tuples
|
||||
scores = [] if (return_dict_in_generate and output_scores) else None
|
||||
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
|
||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
# 3. init tensors to use for "xla-compileable" generate function
|
||||
# define bsz, seq_length
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = tf.ones_like(input_ids[:, 0])
|
||||
cur_len = input_ids.shape[-1]
|
||||
# initialize `generated`, `finished_sequences`, and `current_pos`
|
||||
generated = tf.TensorArray(
|
||||
element_shape=(batch_size,),
|
||||
dtype=tf.int32,
|
||||
dynamic_size=False,
|
||||
size=max_length,
|
||||
clear_after_read=False,
|
||||
)
|
||||
|
||||
while cur_len < max_length:
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
# write prompt to generated
|
||||
for i in range(seq_length):
|
||||
generated = generated.write(i, input_ids[:, i])
|
||||
|
||||
# forward pass to get next token
|
||||
finished_sequences = tf.zeros((batch_size,), dtype=tf.bool)
|
||||
current_pos = tf.ones(shape=(1,), dtype=tf.int32) * seq_length
|
||||
|
||||
# 4. define "xla-compile-able" stop-condition and auto-regressive function
|
||||
# define condition fn
|
||||
def greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
    """Loop condition: continue while at least one sequence is not finished.

    Only `finished_sequences` is inspected; the other arguments are accepted
    so the signature matches the loop state passed to `tf.while_loop`.
    """
    return ~tf.reduce_all(finished_sequences)
|
||||
|
||||
# define body fn
|
||||
def greedy_search_body_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
    """Run one greedy decoding step and return the updated loop state.

    The function closes over values prepared by the enclosing `greedy_search`
    (`use_xla`, `logits_processor`, `batch_size`, `max_length`, `eos_token_id`,
    `pad_token_id`, the output flags and the `scores` / attention /
    hidden-state lists).

    Args:
        generated: `tf.TensorArray` holding one vector of token ids per position.
        finished_sequences: boolean tensor with one flag per batch row.
        next_tokens: token ids handed to `prepare_inputs_for_generation`.
        current_pos: int32 tensor of shape `(1,)`; position the new token is
            written to.
        model_kwargs: extra keyword arguments for the model call.

    Returns:
        `(generated, finished_sequences, next_tokens, current_pos, model_kwargs)`
        after the step, in the same order as the arguments.

    Raises:
        ValueError: if `eos_token_id` is set but `pad_token_id` is not.
    """
    # TODO(pvp, Joao) - `use_xla` can be removed here as soon as `position_ids` are corrected for the non-xla case in gpt2's `prepare_inputs_for_generation`.
    model_inputs = self.prepare_inputs_for_generation(next_tokens, use_xla=use_xla, **model_kwargs)
    # forward pass to get next token logits
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )
    next_token_logits = outputs.logits[:, -1]

    # Store scores, attentions and hidden_states when required
    # (only collected in the non-XLA path)
    if not use_xla and return_dict_in_generate:
        if output_scores:
            scores.append(next_token_logits)

        if output_attentions and self.config.is_encoder_decoder:
            decoder_attentions.append(outputs.decoder_attentions)
            # cross attentions only exist for encoder-decoder models, so they
            # are recorded together with the decoder attentions
            cross_attentions.append(outputs.cross_attentions)
        elif output_attentions and not self.config.is_encoder_decoder:
            decoder_attentions.append(outputs.attentions)

        if output_hidden_states and self.config.is_encoder_decoder:
            decoder_hidden_states.append(outputs.decoder_hidden_states)
        # the decoder-only branch must test `not is_encoder_decoder`; repeating
        # the condition above would make this branch unreachable
        elif output_hidden_states and not self.config.is_encoder_decoder:
            decoder_hidden_states.append(outputs.hidden_states)

    # pre-process distribution
    # TODO(pvp, joao, matt) - all the logits processors need to be adapted
    # to be XLA compatible
    input_ids = None
    if not use_xla:
        # rebuild the (batch, length) ids generated so far for the processors
        input_ids = tf.reshape(generated.concat(), (-1, batch_size))
        input_ids = tf.transpose(input_ids[: current_pos[0]])
    next_tokens_scores = logits_processor(input_ids, next_token_logits)

    # argmax
    next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        unfinished_seq = 1 - tf.cast(finished_sequences, tf.int32)
        next_tokens = next_tokens * unfinished_seq + pad_token_id * (1 - unfinished_seq)
        finished_sequences = finished_sequences | (next_tokens == eos_token_id)

    # update `generated` and `current_pos`
    generated = generated.write(current_pos[0], next_tokens)
    next_tokens = next_tokens[:, None]
    current_pos += 1

    # update model_kwargs
    if use_xla:
        model_kwargs = self._update_model_kwargs_for_xla_generation(
            outputs, model_kwargs, current_pos, max_length
        )
    else:
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # if we don't cache past key values we need the whole input
        if model_kwargs.get("past", None) is None:
            # let's throw out `past` since we don't want `None` tensors
            model_kwargs.pop("past", None)

            next_tokens = tf.reshape(generated.concat(), (-1, batch_size))
            next_tokens = tf.transpose(next_tokens[: current_pos[0]])

    return generated, finished_sequences, next_tokens, current_pos, model_kwargs
|
||||
|
||||
# 5. run generation
|
||||
# 1st generation step has to be run before to initialize `past`
|
||||
generated, finished_sequences, next_tokens, current_pos, model_kwargs = greedy_search_body_fn(
|
||||
generated, finished_sequences, input_ids, current_pos, model_kwargs
|
||||
)
|
||||
|
||||
# 2-to-n generation steps can then be run in autoregressive fashion
|
||||
# only in case 1st generation step does NOT yield EOS token though
|
||||
if greedy_search_cond_fn(generated, finished_sequences, next_tokens, current_pos, model_kwargs):
|
||||
maximum_iterations = max_length - seq_length - 1
|
||||
generated, _, _, current_pos, _ = tf.while_loop(
|
||||
greedy_search_cond_fn,
|
||||
greedy_search_body_fn,
|
||||
(generated, finished_sequences, next_tokens, current_pos, model_kwargs),
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
||||
# 6. prepare outputs
|
||||
output_ids = tf.transpose(tf.reshape(generated.concat(), (-1, batch_size)))
|
||||
|
||||
if not use_xla:
|
||||
# cut for backward compatibility
|
||||
output_ids = output_ids[:, : current_pos[0]]
|
||||
|
||||
if return_dict_in_generate:
|
||||
if self.config.is_encoder_decoder:
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights
|
||||
# and hidden states
|
||||
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||
)
|
||||
|
||||
scores = tuple(scores) if scores is not None else None
|
||||
decoder_attentions = tuple(decoder_attentions) if decoder_attentions is not None else None
|
||||
cross_attentions = tuple(cross_attentions) if cross_attentions is not None else None
|
||||
decoder_hidden_states = tuple(decoder_hidden_states) if decoder_hidden_states is not None else None
|
||||
|
||||
return TFGreedySearchEncoderDecoderOutput(
|
||||
sequences=input_ids,
|
||||
sequences=output_ids,
|
||||
scores=scores,
|
||||
encoder_attentions=encoder_attentions,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
|
@ -1968,13 +2035,13 @@ class TFGenerationMixin:
|
|||
)
|
||||
else:
|
||||
return TFGreedySearchDecoderOnlyOutput(
|
||||
sequences=input_ids,
|
||||
sequences=output_ids,
|
||||
scores=scores,
|
||||
attentions=decoder_attentions,
|
||||
hidden_states=decoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
return input_ids
|
||||
return output_ids
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
|
@ -851,7 +853,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||
def set_output_embeddings(self, value):
    """Set the output embeddings by delegating to `set_input_embeddings`."""
    self.set_input_embeddings(value)
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
|
||||
def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, use_xla=False, **kwargs):
|
||||
# TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
|
||||
# tests will need to be fixed after the change
|
||||
|
||||
|
@ -859,7 +861,81 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
|
|||
if past:
|
||||
inputs = tf.expand_dims(inputs[:, -1], -1)
|
||||
|
||||
return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
|
||||
# TODO(pvp, Joao) - this `if use_xla` statement can be removed, but is left
|
||||
# for a future PR to not change too many things for now.
|
||||
# All statements in this if case apply for both xla and non-xla (as they already do in PyTorch)
|
||||
position_ids = None
|
||||
attention_mask = None
|
||||
if use_xla:
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if past is not None and attention_mask is not None:
|
||||
position_ids = tf.reduce_sum(attention_mask, axis=1, keepdims=True) - 1
|
||||
elif attention_mask is not None:
|
||||
position_ids = tf.math.cumsum(attention_mask, axis=1, exclusive=True)
|
||||
|
||||
return {
|
||||
"input_ids": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
    """Update `past` and `attention_mask` in `model_kwargs` after one XLA generation step.

    On the first call (no `past` in `model_kwargs`) every cache tensor is
    zero-padded along axis 3 so that its size no longer changes between steps,
    and the attention mask is extended to match. On later calls the newest
    cache slice is written into the padded cache at `current_pos - 1` and the
    matching attention-mask position is switched on.

    Args:
        outputs: model output exposing `past_key_values`.
        model_kwargs: generation kwargs; must contain `attention_mask`.
        current_pos: int tensor with the position of the next token.
        max_length: maximum generation length used to size the padded cache.

    Returns:
        `model_kwargs` with `attention_mask` and `past` replaced.
    """
    # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
    # quite some duplicated code patterns it seems
    # also the `attention_mask` is currently used in a somewhat hacky way to
    # correctly influence the `past_key_values` - not sure if this is the way to go
    # Let's keep that for a future PR.
    past = outputs.past_key_values
    is_past_initialized = model_kwargs.pop("past", None) is not None
    attention_mask = model_kwargs.pop("attention_mask")
    batch_size = attention_mask.shape[0]

    if not is_past_initialized:
        # past[0].shape[3] is seq_length of prompt
        num_padding_values = max_length - past[0].shape[3] - 1

        # pad only the end of axis 3 of each rank-5 cache tensor
        padding_values = np.zeros((5, 2), dtype=np.int32)
        padding_values[3, 1] = num_padding_values
        padding_values = tf.constant(padding_values)

        new_past = [tf.pad(past_layer, padding_values) for past_layer in past]

        # Keep the incoming mask, add zeros for the currently-unfilled locations
        # in the past tensor, and a final one for the newest input id
        attention_mask = tf.concat(
            [
                attention_mask,
                tf.zeros((batch_size, num_padding_values), dtype=attention_mask.dtype),
                tf.ones((batch_size, 1), dtype=attention_mask.dtype),
            ],
            axis=1,
        )
    else:
        slice_start_base = tf.constant([0, 0, 0, 1, 0])
        attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
        new_past_index = current_pos - 1

        # Write the last slice to the first open location in the padded past array
        # and then truncate the last slice off the array
        new_past = [
            dynamic_update_slice(
                past_layer[:, :, :, :-1], past_layer[:, :, :, -1:], slice_start_base * new_past_index
            )
            for past_layer in past
        ]

        update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
        attention_mask = dynamic_update_slice(attention_mask, attention_mask_update_slice, update_start)

    # set `attention_mask` and `past`
    model_kwargs["attention_mask"] = attention_mask
    model_kwargs["past"] = tuple(new_past)

    return model_kwargs
|
||||
|
||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
|
|
|
@ -1309,9 +1309,13 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
|||
min_length=min_length,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
# TODO(Patrick) clean-up once generate is fully cleaned up
|
||||
model_kwargs["attention_mask"] = context_attention_mask
|
||||
# TODO(Patrick) remove once generate is fully cleaned up
|
||||
|
||||
if model_kwargs.get("encoder_attentions", None) is None:
|
||||
model_kwargs.pop("encoder_attentions", None)
|
||||
if model_kwargs.get("encoder_hidden_states", None) is None:
|
||||
model_kwargs.pop("encoder_hidden_states", None)
|
||||
|
||||
model_kwargs.pop("output_hidden_states", None)
|
||||
model_kwargs.pop("output_attentions", None)
|
||||
model_kwargs.pop("output_scores", None)
|
||||
|
|
|
@ -21,7 +21,9 @@ import math
|
|||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
|
@ -1545,6 +1547,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||
input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
use_cache=None,
|
||||
|
@ -1562,11 +1565,76 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||
"past_key_values": past,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
def _update_model_kwargs_for_xla_generation(self, outputs, model_kwargs, current_pos, max_length):
    """Update `past` and `decoder_attention_mask` in `model_kwargs` after one XLA generation step.

    On the first call (no `past` in `model_kwargs`) the first two tensors of
    every layer's cache are zero-padded along axis 2 so that their size no
    longer changes between steps, and a fresh decoder attention mask is built.
    On later calls the newest cache slice is written into the padded cache at
    `current_pos - 1` and the matching mask position is switched on. Entries of
    a layer beyond the first two are passed through untouched.

    Args:
        outputs: model output exposing `past_key_values` (a sequence of
            per-layer sequences of tensors).
        model_kwargs: generation kwargs.
        current_pos: int tensor with the position of the next token.
        max_length: maximum generation length used to size the padded cache.

    Returns:
        `model_kwargs` with `decoder_attention_mask` and `past` replaced.
    """
    # TODO(Pvp, Joao, Matt) - this function can be cleaned a bit and refactored
    # quite some duplicated code patterns it seems
    # also the `attention_mask` is currently used in a somewhat hacky way to
    # correctly influence the `past_key_values` - not sure if this is the way to go
    # Let's keep that for a future PR.
    past = outputs.past_key_values
    is_past_initialized = model_kwargs.pop("past", None) is not None
    decoder_attention_mask = model_kwargs.pop("decoder_attention_mask", None)
    batch_size = past[0][0].shape[0]

    if not is_past_initialized:
        # past[0][0].shape[2] is the length already held in the cache
        # (presumably the decoder prompt length - TODO confirm)
        num_padding_values = max_length - past[0][0].shape[2] - 1

        # pad only the end of axis 2 of each rank-4 cache tensor
        padding_values = np.zeros((4, 2), dtype=np.int32)
        padding_values[2, 1] = num_padding_values
        padding_values = tf.constant(padding_values)

        new_past = tuple(
            tuple(tf.pad(tensor, padding_values) for tensor in past_layer[:2]) + tuple(past_layer[2:])
            for past_layer in past
        )

        # One for decoder_start_token_id, zeros for the currently-unfilled locations
        # in the past tensor, and a final one for the newest input id.
        # Any `decoder_attention_mask` passed in is replaced here.
        decoder_attention_mask = tf.concat(
            [
                tf.ones((batch_size, 1), dtype=tf.int32),
                tf.zeros((batch_size, num_padding_values), dtype=tf.int32),
                tf.ones((batch_size, 1), dtype=tf.int32),
            ],
            axis=1,
        )
    else:
        slice_start_base = tf.constant([0, 0, 1, 0])
        decoder_attention_mask_update_slice = tf.ones((batch_size, 1), dtype=decoder_attention_mask.dtype)
        new_past_index = current_pos - 1

        # Write the last slice to the first open location in the padded past array
        # and then truncate the last slice off the array
        new_past = tuple(
            tuple(
                dynamic_update_slice(tensor[:, :, :-1], tensor[:, :, -1:], slice_start_base * new_past_index)
                for tensor in past_layer[:2]
            )
            + tuple(past_layer[2:])
            for past_layer in past
        )

        update_start = tf.constant([0, 1], dtype=tf.int32) * new_past_index
        decoder_attention_mask = dynamic_update_slice(
            decoder_attention_mask, decoder_attention_mask_update_slice, update_start
        )

    # set `decoder_attention_mask` and `past`
    model_kwargs["decoder_attention_mask"] = decoder_attention_mask
    model_kwargs["past"] = new_past

    return model_kwargs
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
    """Return decoder input ids for `labels` by applying `self._shift_right`."""
    return self._shift_right(labels)
|
||||
|
||||
|
|
|
@ -660,29 +660,16 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||
else:
|
||||
model.gradient_checkpointing_disable()
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
|
||||
# The dog
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device)
|
||||
|
||||
# The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
# fmt: off
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290,
|
||||
]
|
||||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
if verify_outputs:
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
|
|
@ -294,6 +294,21 @@ class TFGPT2ModelTester:
|
|||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt2_xla_generate(self, config, input_ids, *args):
    """Check that `generate` returns the same ids with and without XLA compilation.

    `eos_token_id` is disabled and `max_length` is set to 10 on `config`
    before the model is built.
    """
    config.eos_token_id = None
    config.max_length = 10
    model = TFGPT2LMHeadModel(config=config)

    # make sure there are no pad tokens in prompt
    input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id - 1)

    generated = model.generate(input_ids)

    generate_xla = tf.function(model.generate, jit_compile=True)
    generated_xla = generate_xla(input_ids)

    self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
|
||||
|
||||
def create_and_check_gpt2_double_head(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
):
|
||||
|
@ -393,6 +408,10 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gpt2_lm_head(*config_and_inputs)
|
||||
|
||||
def test_gpt2_xla_generate(self):
    """Run the model tester's XLA-vs-eager `generate` comparison for GPT2."""
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_gpt2_xla_generate(*config_and_inputs)
|
||||
|
||||
def test_gpt2_double_head(self):
    """Run the model tester's double-head check on freshly prepared inputs."""
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_gpt2_double_head(*config_and_inputs)
|
||||
|
@ -513,3 +532,18 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
|||
# fmt: on
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
def test_lm_generate_gpt2_xla(self):
    """This test gives the exact same results as the non-xla test above"""
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32)  # The dog

    # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
    # fmt: off
    expected_output_ids = [464, 3290, 373, 1043, 287, 257, 2214, 1474, 262, 16246, 286, 2688, 290, 2688, 27262, 13, 198, 198, 464, 3290]
    # fmt: on
    # compile `generate` with XLA and compare against the reference ids
    xla_generate = tf.function(model.generate, jit_compile=True)

    output_ids = xla_generate(input_ids, do_sample=False)
    self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
|
|
@ -227,6 +227,23 @@ class TFT5ModelTester:
|
|||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
|
||||
|
||||
def create_and_check_t5_xla_generate(self, config, input_ids, *args):
    """Check that greedy `generate` returns the same ids with and without XLA compilation.

    `config` is switched to greedy search (`do_sample=False`, `num_beams=1`)
    with no EOS token and `max_length=10` before the model is built.
    """
    config.eos_token_id = None
    config.max_length = 10
    config.do_sample = False
    config.num_beams = 1
    model = TFT5ForConditionalGeneration(config=config)

    # make sure there are no pad tokens in prompt
    input_ids = tf.where(input_ids != config.pad_token_id, input_ids, config.pad_token_id + 5)

    generated = model.generate(input_ids)

    generate_xla = tf.function(model.generate, jit_compile=True)
    generated_xla = generate_xla(input_ids)

    self.parent.assertListEqual(generated.numpy().tolist(), generated_xla.numpy().tolist())
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask, token_labels) = config_and_inputs
|
||||
|
@ -280,6 +297,10 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def test_t5_model_xla_generate(self):
    """Run the model tester's XLA-vs-eager `generate` comparison for T5."""
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_t5_xla_generate(*config_and_inputs)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -454,6 +475,27 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TFT5GenerationIntegrationTests(unittest.TestCase):
|
||||
@slow
def test_greedy_xla_generate_simple(self):
    """Translate one sentence with t5-small, eagerly and under XLA, and compare both to the expected text."""
    model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    sentence = "Translate English to German: Today is a beautiful day."
    input_ids = tokenizer(sentence, return_tensors="tf", padding=True).input_ids

    xla_generate = tf.function(model.generate, jit_compile=True)

    output_ids = model.generate(input_ids)
    output_ids_xla = xla_generate(input_ids)

    output_strings = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    output_strings_xla = tokenizer.batch_decode(output_ids_xla, skip_special_tokens=True)

    expected_output_string = ["Heute ist ein schöner Tag."]

    self.assertListEqual(expected_output_string, output_strings)
    self.assertListEqual(expected_output_string, output_strings_xla)
|
||||
|
||||
@slow
|
||||
def test_greedy_generate(self):
|
||||
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
|
Loading…
Reference in New Issue