Fix T5 and BART for TF (#9063)

* Fix T5 for graph compilation+execution

* Fix BART

* Fix import

* Fix naming

* fix attribute name

* Oops

* fix import

* fix tests

* fix tests

* Update test

* Add missing import

* Address Patrick's comments

* Style

* Address Patrick's comment
Julien Plu 2020-12-14 18:47:00 +01:00 committed by GitHub
parent a9c8bff724
commit df3f4d2aef
8 changed files with 151 additions and 166 deletions

View File

@ -91,8 +91,6 @@ TensorFlow loss functions
TensorFlow Helper Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive
.. autofunction:: transformers.modeling_tf_utils.get_initializer
.. autofunction:: transformers.modeling_tf_utils.keras_serializable

View File

@ -51,7 +51,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end
tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators
tf_name = tf_name[1:] # Remove level zero
# Some weights have a single name without "/" such as final_logits_bias in BART
if len(tf_name) > 1:
tf_name = tf_name[1:] # Remove level zero
# When should we transpose the weights
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name)
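A standalone sketch of the effect of the new guard (the helper name strip_level_zero is hypothetical, not the library function): single-segment weight names such as final_logits_bias keep their only level, while scoped names drop the model-level prefix as before.

import re

def strip_level_zero(tf_name):
    name = re.sub(r"//+", "/", tf_name).split("/")
    # some weights have a single name without "/", e.g. final_logits_bias in BART
    if len(name) > 1:
        name = name[1:]  # remove level zero (the model scope)
    return name

print(strip_level_zero("final_logits_bias"))               # ['final_logits_bias']
print(strip_level_zero("tf_bart/encoder/layer_0/kernel"))  # ['encoder', 'layer_0', 'kernel']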

View File

@ -354,7 +354,7 @@ def input_processing(func, config, input_ids, **kwargs):
if isinstance(v, allowed_types) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
if isinstance(input_ids, (tuple, list)):
for i, input in enumerate(input_ids):
@ -372,7 +372,7 @@ def input_processing(func, config, input_ids, **kwargs):
output[parameter_names[i]] = input
else:
raise ValueError(
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
)
elif isinstance(input_ids, (dict, BatchEncoding)):
if "inputs" in input_ids:
@ -399,13 +399,13 @@ def input_processing(func, config, input_ids, **kwargs):
)
continue
else:
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else:
if isinstance(input_ids, tf.Tensor) or input_ids is None:
output[parameter_names[0]] = input_ids
else:
raise ValueError(
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
)
for name in parameter_names:
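The same validation pattern in isolation (a hedged sketch; the allowed_types tuple below is illustrative, the real one is built inside input_processing): the error message now names the full set of accepted types instead of hard-coding tf.Tensor.

import tensorflow as tf

allowed_types = (tf.Tensor, bool, int, dict, tuple, list)  # illustrative

def check_input(name, value):
    if isinstance(value, allowed_types) or value is None:
        return value
    raise ValueError(f"Data of type {type(value)} is not allowed only {allowed_types} is accepted for {name}.")

check_input("input_ids", tf.constant([[1, 2, 3]]))  # accepted
check_input("use_cache", False)                     # accepted; 1.5 would raise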
@ -1366,31 +1366,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
"""
Function arguments can be passed as boolean tensors or bool variables; to cope with Keras serialization we need to
cast the bool arguments (like :obj:`output_attentions` for instance) to a proper boolean if they are tensors.
Args:
bool_variable (:obj:`Union[tf.Tensor, bool]`):
The variable to convert to a boolean.
default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
The default value to use in case the tensor has no numpy attribute.
Returns:
:obj:`bool`: The converted value.
"""
# if bool variable is tensor and has numpy value
if tf.is_tensor(bool_variable):
if hasattr(bool_variable, "numpy"):
return bool(bool_variable.numpy())
elif default_tensor_to_true:
return True
# else variable is bool
return bool_variable
class TFWrappedEmbeddings:
"""
this class wraps the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problems with
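Why cast_bool_to_primitive could be dropped (a minimal sketch of the underlying behavior, not library code): when control-flow flags stay plain Python bools, tf.function resolves the branch at trace time, so there is no symbolic tensor that would need casting back.

import tensorflow as tf

@tf.function
def forward(x, use_cache: bool):
    # a plain Python bool is baked into the trace; no .numpy() needed
    if use_cache:
        return x * 2.0
    return x

print(forward(tf.constant(1.0), True).numpy())   # 2.0
print(forward(tf.constant(1.0), False).numpy())  # 1.0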

View File

@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
cast_bool_to_primitive,
input_processing,
keras_serializable,
shape_list,
@ -258,9 +257,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
assert shape_list(x) == shape_list(
residual
), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
tf.debugging.assert_equal(
shape_list(x),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}",
)
x = self.dropout(x, training=training)
x = residual + x
if not self.normalize_before:
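The assertion swap in isolation (toy shapes, a hedged sketch): a plain Python assert only runs while tracing, whereas tf.debugging.assert_equal is recorded as a graph op and also validates dynamic shapes at execution time.

import tensorflow as tf

@tf.function
def add_residual(x, residual):
    tf.debugging.assert_equal(
        tf.shape(x),
        tf.shape(residual),
        message="Self attn modified the shape of query",
    )
    return residual + x

print(add_residual(tf.ones((2, 3)), tf.ones((2, 3))).shape)  # (2, 3)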
@ -295,9 +296,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.layerdrop = config.encoder_layerdrop
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
@ -328,7 +326,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
if config.add_final_layer_norm
else None
)
self.return_dict = config.return_dict
def call(
self,
@ -355,10 +352,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
return_dict = return_dict if return_dict is not None else self.return_dict
# check attention mask and invert
if attention_mask is not None:
assert (
@ -546,9 +539,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.use_cache = config.use_cache
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
def call(
@ -565,14 +555,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
return_dict=None,
training=False,
):
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
if use_cache:
assert not training, "Training + use cache are incompatible"
# check attention mask and invert
use_cache = cast_bool_to_primitive(use_cache)
if encoder_padding_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask)
@ -1046,7 +1029,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
self.use_cache = config.use_cache
# final_logits_bias is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
)
def resize_token_embeddings(self, new_num_tokens):

View File

@ -32,12 +32,16 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPast,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
input_processing,
keras_serializable,
shape_list,
@ -311,7 +315,7 @@ class TFT5Attention(tf.keras.layers.Layer):
)
# to cope with keras serialization
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
if self.is_decoder and use_cache:
present_key_value_state = (key_states, value_states)
else:
present_key_value_state = None
@ -594,6 +598,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
) -> Tuple:
@ -610,6 +615,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
@ -713,10 +719,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
assert inputs["head_mask"] is None, "Head mask not supported"
inputs["head_mask"] = [None] * self.num_hidden_layers
present_key_value_states = ()
all_hidden_states = ()
all_attentions = ()
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
position_bias = None
encoder_decoder_position_bias = None
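The accumulator convention above, reduced to a toy loop: a tuple is started only when the corresponding flag is on; otherwise the field stays None so the output objects report it as absent.

output_hidden_states = True
all_hidden_states = () if output_hidden_states else None

for hidden in ("h0", "h1", "h2"):  # stand-ins for per-layer tensors
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden,)

print(all_hidden_states)  # ('h0', 'h1', 'h2'), or None when the flag is off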
@ -725,7 +730,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
@ -739,6 +743,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
output_attentions=inputs["output_attentions"],
training=inputs["training"],
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
@ -747,10 +752,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
position_bias = layer_outputs[2]
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if inputs["output_attentions"]:
all_attentions = all_attentions + (layer_outputs[3],)
@ -762,15 +770,30 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if inputs["output_hidden_states"]:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
# need to check if is decoder here as well for special cases when using keras compile
if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder:
outputs = outputs + (present_key_value_states,)
if inputs["output_hidden_states"]:
outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if not inputs["return_dict"]:
outputs = (hidden_states,)
# need to check if is decoder here as well for special cases when using keras compile
if inputs["use_cache"] and self.is_decoder:
outputs = outputs + (present_key_value_states,)
if inputs["output_hidden_states"]:
outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if self.is_decoder:
return TFBaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
else:
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
####################################################
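A sketch of the two return styles the main layer now supports, constructing the output classes directly with toy tensors (the field names match the diff above; this is illustration, not a model call):

import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPast

hidden = tf.ones((1, 4, 8))

decoder_out = TFBaseModelOutputWithPast(
    last_hidden_state=hidden,
    past_key_values=None,  # tuple of cached states when use_cache=True
    hidden_states=None,    # tuple of per-layer states when output_hidden_states=True
    attentions=None,       # tuple of attention maps when output_attentions=True
)
encoder_out = TFBaseModelOutput(last_hidden_state=hidden)

print(decoder_out.last_hidden_state.shape)  # (1, 4, 8)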
@ -1102,6 +1125,7 @@ class TFT5Model(TFT5PreTrainedModel):
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@ -1119,38 +1143,25 @@ class TFT5Model(TFT5PreTrainedModel):
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
past = (
(inputs["encoder_outputs"], decoder_outputs[1])
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
else None
)
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if not inputs["return_dict"]:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + inputs["encoder_outputs"]
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
# TF refuses to compile anymore.
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
decoder_outputs = decoder_outputs + (None,)
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs[0],
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=inputs["encoder_outputs"][0],
encoder_hidden_states=inputs["encoder_outputs"][1],
encoder_attentions=inputs["encoder_outputs"][2],
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
)
@ -1280,6 +1291,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@ -1313,6 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
@ -1327,37 +1340,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
past = (
(inputs["encoder_outputs"], decoder_outputs[1])
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
else None
)
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if not inputs["return_dict"]:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
return ((loss,) + output) if loss is not None else output
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
# TF refuses to compile anymore.
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
decoder_outputs = decoder_outputs + (None,)
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
elif isinstance(inputs["encoder_outputs"], tuple):
last_hidden_state = inputs["encoder_outputs"][0]
hidden_states = None
attentions = None
idx = 0
if inputs["output_hidden_states"]:
idx += 1
hidden_states = inputs["encoder_outputs"][idx]
if inputs["output_attentions"]:
idx += 1
attentions = inputs["encoder_outputs"][idx]
inputs["encoder_outputs"] = TFBaseModelOutput(
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
attentions=attentions,
)
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=inputs["encoder_outputs"][0],
encoder_hidden_states=inputs["encoder_outputs"][1],
encoder_attentions=inputs["encoder_outputs"][2],
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
)
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
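The tuple-unwrapping bookkeeping above in isolation (toy values): the position of each optional entry in a legacy encoder_outputs tuple depends on which flags were set, hence the moving index.

encoder_outputs = ("last_hidden", "all_hidden", "all_attn")  # both flags on
output_hidden_states, output_attentions = True, True

last_hidden_state = encoder_outputs[0]
hidden_states = attentions = None
idx = 0
if output_hidden_states:
    idx += 1
    hidden_states = encoder_outputs[idx]
if output_attentions:
    idx += 1
    attentions = encoder_outputs[idx]

print(last_hidden_state, hidden_states, attentions)  # last_hidden all_hidden all_attn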
@ -1498,19 +1515,15 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
if not inputs["return_dict"]:
return encoder_outputs
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
encoder_outputs = encoder_outputs + (None,)
return TFBaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1],
attentions=encoder_outputs[2],
last_hidden_state=encoder_outputs.last_hidden_state,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)

View File

@ -118,14 +118,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# inputs_embeds not supported
pass
def test_saved_model_with_hidden_states_output(self):
# Should be uncommented during Patrick's TF refactor
pass
def test_saved_model_with_attentions_output(self):
# Should be uncommented during Patrick's TF refactor
pass
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -171,6 +171,11 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# A saved model is always executed in graph mode. Since we merged PR #8777,
# the booleans in graph mode are always the ones from the config, so we update
# the use_cache property, if it exists, to keep the booleans consistent with the inputs
if "use_cache" in class_inputs_dict:
config.use_cache = class_inputs_dict.pop("use_cache")
model = model_class(config)
num_out = len(model(class_inputs_dict))
model._saved_model_inputs_spec = None
@ -207,6 +212,11 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# A saved model is always executed in graph mode. Since we merged PR #8777,
# the booleans in graph mode are always the ones from the config, so we update
# the use_cache property, if it exists, to keep the booleans consistent with the inputs
if "use_cache" in class_inputs_dict:
config.use_cache = class_inputs_dict.pop("use_cache")
model = model_class(config)
num_out = len(model(class_inputs_dict))
model._saved_model_inputs_spec = None
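The test-side convention in isolation (illustrative config and dict, not the real tester): use_cache is moved from the input dict onto the config before the model is built, so graph-mode execution sees the same boolean the inputs would have carried.

class ToyConfig:
    use_cache = True

config = ToyConfig()
class_inputs_dict = {"input_ids": [[0, 1, 2]], "use_cache": False}

if "use_cache" in class_inputs_dict:
    config.use_cache = class_inputs_dict.pop("use_cache")

print(config.use_cache)   # False
print(class_inputs_dict)  # {'input_ids': [[0, 1, 2]]}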
@ -249,10 +259,11 @@ class TFModelTesterMixin:
if "T5" in main_layer_class.__name__:
# Take the same values as in TFT5ModelTester for this shared layer
shared = TFSharedEmbeddings(99, 32, name="shared")
config.use_cache = False
config.use_cache = inputs_dict.pop("use_cache", None)
main_layer = main_layer_class(config, embed_tokens=shared)
else:
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
}
@ -321,10 +332,13 @@ class TFModelTesterMixin:
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict(
(name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
)
pt_inputs_dict = {}
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool:
pt_inputs_dict[name] = key
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
@ -358,10 +372,13 @@ class TFModelTesterMixin:
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict(
(name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
)
pt_inputs_dict = {}
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool:
key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
@ -574,13 +591,29 @@ class TFModelTesterMixin:
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
hidden_states = outputs[-1]
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)
if model.config.is_encoder_decoder:
encoder_hidden_states = outputs.encoder_hidden_states
decoder_hidden_states = outputs.decoder_hidden_states
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(encoder_hidden_states), expected_num_layers)
self.assertListEqual(
list(encoder_hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)
self.assertEqual(len(decoder_hidden_states), expected_num_layers)
self.assertListEqual(
list(decoder_hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)
else:
hidden_states = outputs.hidden_states
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.seq_length, self.model_tester.hidden_size],
)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
@ -796,7 +829,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
input_ids = inputs_dict["input_ids"]
for model_class in self.all_generative_model_classes:
model = model_class(config)

View File

@ -133,8 +133,6 @@ class TFT5ModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_values = outputs
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@ -142,7 +140,7 @@ class TFT5ModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
@ -164,7 +162,7 @@ class TFT5ModelTester:
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
_, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@ -187,7 +185,7 @@ class TFT5ModelTester:
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
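What the attribute access in the updated assertions relies on (a sketch using the real output class with toy values): model outputs are ModelOutput objects, so past_key_values is read by name rather than by tuple position.

import tensorflow as tf
from transformers.modeling_tf_outputs import TFSeq2SeqModelOutput

outputs = TFSeq2SeqModelOutput(
    last_hidden_state=tf.ones((1, 3, 8)),
    past_key_values=("toy_cache",),  # stand-in for the model's cached states
)

print(outputs.past_key_values)  # ('toy_cache',)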
@ -208,8 +206,6 @@ class TFT5ModelTester:
# first forward pass
outputs = model(input_ids, use_cache=True)
output, past_key_values = outputs
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@ -217,7 +213,7 @@ class TFT5ModelTester:
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
@ -236,7 +232,7 @@ class TFT5ModelTester:
"input_ids": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
"use_cache": tf.convert_to_tensor([False]),
"use_cache": False,
}
return config, inputs_dict
@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFT5Model.from_pretrained("t5-small")
self.assertIsNotNone(model)
@slow
def test_saved_model_with_attentions_output(self):
pass
@slow
def test_saved_model_with_hidden_states_output(self):
pass
class TFT5EncoderOnlyModelTester:
def __init__(
@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
def setUp(self):