Fix T5 and BART for TF (#9063)
* Fix T5 for graphe compilation+execution * Fix BART * Fix import * Fix naming * fix attribute name * Oops * fix import * fix tests * fix tests * Update test * Add mising import * Address Patrick's comments * Style * Address Patrick's comment
This commit is contained in:
parent
a9c8bff724
commit
df3f4d2aef
|
@ -91,8 +91,6 @@ TensorFlow loss functions
|
|||
TensorFlow Helper Functions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: transformers.modeling_tf_utils.cast_bool_to_primitive
|
||||
|
||||
.. autofunction:: transformers.modeling_tf_utils.get_initializer
|
||||
|
||||
.. autofunction:: transformers.modeling_tf_utils.keras_serializable
|
||||
|
|
|
@ -51,7 +51,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
|
|||
) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
tf_name = re.sub(r"//+", "/", tf_name) # Remove empty levels at the end
|
||||
tf_name = tf_name.split("/") # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
||||
tf_name = tf_name[1:] # Remove level zero
|
||||
# Some weights have a single name withtout "/" such as final_logits_bias in BART
|
||||
if len(tf_name) > 1:
|
||||
tf_name = tf_name[1:] # Remove level zero
|
||||
|
||||
# When should we transpose the weights
|
||||
transpose = bool(tf_name[-1] == "kernel" or "emb_projs" in tf_name or "out_projs" in tf_name)
|
||||
|
|
|
@ -354,7 +354,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||
if isinstance(v, allowed_types) or v is None:
|
||||
output[k] = v
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
|
||||
if isinstance(input_ids, (tuple, list)):
|
||||
for i, input in enumerate(input_ids):
|
||||
|
@ -372,7 +372,7 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||
output[parameter_names[i]] = input
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input)} is not allowed only tf.Tensor is accepted for {parameter_names[i]}."
|
||||
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
|
||||
)
|
||||
elif isinstance(input_ids, (dict, BatchEncoding)):
|
||||
if "inputs" in input_ids:
|
||||
|
@ -399,13 +399,13 @@ def input_processing(func, config, input_ids, **kwargs):
|
|||
)
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only tf.Tensor is accepted for {k}.")
|
||||
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
|
||||
else:
|
||||
if isinstance(input_ids, tf.Tensor) or input_ids is None:
|
||||
output[parameter_names[0]] = input_ids
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Data of type {type(input_ids)} is not allowed only tf.Tensor is accepted for {parameter_names[0]}."
|
||||
f"Data of type {type(input_ids)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
|
||||
)
|
||||
|
||||
for name in parameter_names:
|
||||
|
@ -1366,31 +1366,6 @@ def get_initializer(initializer_range: float = 0.02) -> tf.initializers.Truncate
|
|||
return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
|
||||
|
||||
|
||||
def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
|
||||
"""
|
||||
Function arguments can be inserted as boolean tensor and bool variables to cope with Keras serialization we need to
|
||||
cast the bool arguments (like :obj:`output_attentions` for instance) to correct boolean if it is a tensor.
|
||||
|
||||
Args:
|
||||
bool_variable (:obj:`Union[tf.Tensor, bool]`):
|
||||
The variable to convert to a boolean.
|
||||
default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
|
||||
The default value to use in case the tensor has no numpy attribute.
|
||||
|
||||
Returns:
|
||||
:obj:`bool`: The converted value.
|
||||
"""
|
||||
# if bool variable is tensor and has numpy value
|
||||
if tf.is_tensor(bool_variable):
|
||||
if hasattr(bool_variable, "numpy"):
|
||||
return bool(bool_variable.numpy())
|
||||
elif default_tensor_to_true:
|
||||
return True
|
||||
|
||||
# else variable is bool
|
||||
return bool_variable
|
||||
|
||||
|
||||
class TFWrappedEmbeddings:
|
||||
"""
|
||||
this class wraps a the TFSharedEmbeddingTokens layer into a python 'no-keras-layer' class to avoid problem with
|
||||
|
|
|
@ -41,7 +41,6 @@ from ...modeling_tf_utils import (
|
|||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -258,9 +257,11 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
|
||||
assert shape_list(x) == shape_list(
|
||||
residual
|
||||
), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
|
||||
tf.debugging.assert_equal(
|
||||
shape_list(x),
|
||||
shape_list(residual),
|
||||
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}",
|
||||
)
|
||||
x = self.dropout(x, training=training)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
|
@ -295,9 +296,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.layerdrop = config.encoder_layerdrop
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
@ -328,7 +326,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
if config.add_final_layer_norm
|
||||
else None
|
||||
)
|
||||
self.return_dict = config.return_dict
|
||||
|
||||
def call(
|
||||
self,
|
||||
|
@ -355,10 +352,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
- **all_attentions** (List[tf.Tensor]): Attention weights for each layer.
|
||||
During training might not be of length n_layers because of layer dropout.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
return_dict = return_dict if return_dict is not None else self.return_dict
|
||||
|
||||
# check attention mask and invert
|
||||
if attention_mask is not None:
|
||||
assert (
|
||||
|
@ -546,9 +539,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||
)
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.use_cache = config.use_cache
|
||||
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
|
||||
|
||||
def call(
|
||||
|
@ -565,14 +555,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||
return_dict=None,
|
||||
training=False,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
if use_cache:
|
||||
assert not training, "Training + use cache are incompatible"
|
||||
# check attention mask and invert
|
||||
use_cache = cast_bool_to_primitive(use_cache)
|
||||
if encoder_padding_mask is not None:
|
||||
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||||
|
||||
|
@ -1046,7 +1029,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
self.use_cache = config.use_cache
|
||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens):
|
||||
|
|
|
@ -32,12 +32,16 @@ from ...file_utils import (
|
|||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPast,
|
||||
TFSeq2SeqLMOutput,
|
||||
TFSeq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
cast_bool_to_primitive,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
@ -311,7 +315,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||
)
|
||||
|
||||
# to cope with keras serialization
|
||||
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
|
||||
if self.is_decoder and use_cache:
|
||||
present_key_value_state = (key_states, value_states)
|
||||
else:
|
||||
present_key_value_state = None
|
||||
|
@ -594,6 +598,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
training=False,
|
||||
**kwargs,
|
||||
) -> Tuple:
|
||||
|
@ -610,6 +615,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
@ -713,10 +719,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
|
||||
assert inputs["head_mask"] is None, "Head mask not supported"
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
|
||||
present_key_value_states = ()
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
|
||||
all_hidden_states = () if inputs["output_hidden_states"] else None
|
||||
all_attentions = () if inputs["output_attentions"] else None
|
||||
position_bias = None
|
||||
encoder_decoder_position_bias = None
|
||||
|
||||
|
@ -725,7 +730,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
for i, (layer_module, past_key_value) in enumerate(zip(self.block, inputs["past_key_values"])):
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask=extended_attention_mask,
|
||||
|
@ -739,6 +743,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
output_attentions=inputs["output_attentions"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
# layer_outputs is a tuple with:
|
||||
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||
hidden_states, present_key_value_state = layer_outputs[:2]
|
||||
|
@ -747,10 +752,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
|
||||
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
|
||||
position_bias = layer_outputs[2]
|
||||
|
||||
if self.is_decoder and inputs["encoder_hidden_states"] is not None:
|
||||
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
|
||||
encoder_decoder_position_bias = layer_outputs[4 if inputs["output_attentions"] else 3]
|
||||
|
||||
# append next layer key value states
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
if present_key_value_state is not None and inputs["use_cache"] and self.is_decoder:
|
||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_attentions = all_attentions + (layer_outputs[3],)
|
||||
|
@ -762,15 +770,30 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
# need to check if is decoder here as well for special cases when using keras compile
|
||||
if cast_bool_to_primitive(inputs["use_cache"], self.use_cache) is True and self.is_decoder:
|
||||
outputs = outputs + (present_key_value_states,)
|
||||
if inputs["output_hidden_states"]:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if inputs["output_attentions"]:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
if not inputs["return_dict"]:
|
||||
outputs = (hidden_states,)
|
||||
# need to check if is decoder here as well for special cases when using keras compile
|
||||
if inputs["use_cache"] and self.is_decoder:
|
||||
outputs = outputs + (present_key_value_states,)
|
||||
if inputs["output_hidden_states"]:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if inputs["output_attentions"]:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
if self.is_decoder:
|
||||
return TFBaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=present_key_value_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
else:
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
####################################################
|
||||
|
@ -1102,6 +1125,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||
use_cache=False,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
|
@ -1119,38 +1143,25 @@ class TFT5Model(TFT5PreTrainedModel):
|
|||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
past = (
|
||||
(inputs["encoder_outputs"], decoder_outputs[1])
|
||||
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
|
||||
else None
|
||||
)
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
return decoder_outputs + inputs["encoder_outputs"]
|
||||
|
||||
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
|
||||
# TF refuses to compile anymore.
|
||||
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
|
||||
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
|
||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
|
||||
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
|
||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
|
||||
decoder_outputs = decoder_outputs + (None,)
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs[0],
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs[2],
|
||||
decoder_attentions=decoder_outputs[3],
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"][0],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][1],
|
||||
encoder_attentions=inputs["encoder_outputs"][2],
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1280,6 +1291,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||
head_mask=inputs["head_mask"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
|
@ -1313,6 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||
use_cache=inputs["use_cache"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
|
@ -1327,37 +1340,41 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
|||
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
|
||||
|
||||
past = (
|
||||
(inputs["encoder_outputs"], decoder_outputs[1])
|
||||
if cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache)
|
||||
else None
|
||||
)
|
||||
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
|
||||
if not inputs["return_dict"]:
|
||||
if past is not None:
|
||||
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
|
||||
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
|
||||
# TF refuses to compile anymore.
|
||||
if not cast_bool_to_primitive(inputs["use_cache"], self.config.use_cache):
|
||||
decoder_outputs = decoder_outputs[:1] + (None,) + decoder_outputs[1:]
|
||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"][:1] + (None,) + inputs["encoder_outputs"][1:]
|
||||
decoder_outputs = decoder_outputs[:2] + (None,) + decoder_outputs[2:]
|
||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
||||
inputs["encoder_outputs"] = inputs["encoder_outputs"] + (None,)
|
||||
decoder_outputs = decoder_outputs + (None,)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
|
||||
elif isinstance(inputs["encoder_outputs"], tuple):
|
||||
last_hidden_state = inputs["encoder_outputs"][0]
|
||||
hidden_states = None
|
||||
attentions = None
|
||||
idx = 0
|
||||
if inputs["output_hidden_states"]:
|
||||
idx += 1
|
||||
hidden_states = inputs["encoder_outputs"][idx]
|
||||
if inputs["output_attentions"]:
|
||||
idx += 1
|
||||
attentions = inputs["encoder_outputs"][idx]
|
||||
|
||||
inputs["encoder_outputs"] = TFBaseModelOutput(
|
||||
last_hidden_state=last_hidden_state,
|
||||
hidden_states=hidden_states,
|
||||
attentions=attentions,
|
||||
)
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=past,
|
||||
decoder_hidden_states=decoder_outputs[2],
|
||||
decoder_attentions=decoder_outputs[3],
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"][0],
|
||||
encoder_hidden_states=inputs["encoder_outputs"][1],
|
||||
encoder_attentions=inputs["encoder_outputs"][2],
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
|
||||
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
|
||||
encoder_attentions=inputs["encoder_outputs"].attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
|
||||
|
@ -1498,19 +1515,15 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
|
|||
use_cache=False,
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return encoder_outputs
|
||||
|
||||
if not cast_bool_to_primitive(inputs["output_hidden_states"], self.config.output_hidden_states):
|
||||
encoder_outputs = encoder_outputs[:1] + (None,) + encoder_outputs[1:]
|
||||
if not cast_bool_to_primitive(inputs["output_attentions"], self.config.output_attentions):
|
||||
encoder_outputs = encoder_outputs + (None,)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1],
|
||||
attentions=encoder_outputs[2],
|
||||
last_hidden_state=encoder_outputs.last_hidden_state,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
|
|
@ -118,14 +118,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
|
|
@ -171,6 +171,11 @@ class TFModelTesterMixin:
|
|||
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
# A saved model is always executed in graph mode, since we merged the PR #8777
|
||||
# the booleans in graph mode are always the ones in the config, then we update
|
||||
# the use_cache property if it exists in order to have similar booleans with the inputs
|
||||
if "use_cache" in class_inputs_dict:
|
||||
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
|
@ -207,6 +212,11 @@ class TFModelTesterMixin:
|
|||
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
# A saved model is always executed in graph mode, since we merged the PR #8777
|
||||
# the booleans in graph mode are always the ones in the config, then we update
|
||||
# the use_cache property if it exists in order to have similar booleans with the inputs
|
||||
if "use_cache" in class_inputs_dict:
|
||||
config.use_cache = class_inputs_dict.pop("use_cache")
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
|
@ -249,10 +259,11 @@ class TFModelTesterMixin:
|
|||
if "T5" in main_layer_class.__name__:
|
||||
# Take the same values than in TFT5ModelTester for this shared layer
|
||||
shared = TFSharedEmbeddings(99, 32, name="shared")
|
||||
config.use_cache = False
|
||||
config.use_cache = inputs_dict.pop("use_cache", None)
|
||||
main_layer = main_layer_class(config, embed_tokens=shared)
|
||||
else:
|
||||
main_layer = main_layer_class(config)
|
||||
|
||||
symbolic_inputs = {
|
||||
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
|
||||
}
|
||||
|
@ -321,10 +332,13 @@ class TFModelTesterMixin:
|
|||
|
||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||
pt_model.eval()
|
||||
pt_inputs_dict = dict(
|
||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
||||
)
|
||||
pt_inputs_dict = {}
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||
if type(key) == bool:
|
||||
pt_inputs_dict[name] = key
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
|
@ -358,10 +372,13 @@ class TFModelTesterMixin:
|
|||
|
||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||
pt_model.eval()
|
||||
pt_inputs_dict = dict(
|
||||
(name, torch.from_numpy(key.numpy()).to(torch.long))
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items()
|
||||
)
|
||||
pt_inputs_dict = {}
|
||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
||||
if type(key) == bool:
|
||||
key = np.array(key, dtype=bool)
|
||||
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
|
@ -574,13 +591,29 @@ class TFModelTesterMixin:
|
|||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
hidden_states = outputs[-1]
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_hidden_states = outputs.encoder_hidden_states
|
||||
decoder_hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(encoder_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(encoder_hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
self.assertEqual(len(decoder_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.hidden_states
|
||||
self.assertEqual(config.output_attentions, False)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
|
@ -796,7 +829,7 @@ class TFModelTesterMixin:
|
|||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
|
|
@ -133,8 +133,6 @@ class TFT5ModelTester:
|
|||
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
|
||||
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
|
||||
|
||||
output, past_key_values = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
||||
|
@ -142,7 +140,7 @@ class TFT5ModelTester:
|
|||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||
|
@ -164,7 +162,7 @@ class TFT5ModelTester:
|
|||
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
|
||||
|
||||
# first forward pass
|
||||
_, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
|
@ -187,7 +185,7 @@ class TFT5ModelTester:
|
|||
|
||||
# get two different outputs
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
|
||||
|
@ -208,8 +206,6 @@ class TFT5ModelTester:
|
|||
# first forward pass
|
||||
outputs = model(input_ids, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
|
||||
|
@ -217,7 +213,7 @@ class TFT5ModelTester:
|
|||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=past_key_values)[0]
|
||||
output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]
|
||||
|
||||
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
|
||||
|
||||
|
@ -236,7 +232,7 @@ class TFT5ModelTester:
|
|||
"input_ids": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_attention_mask": input_mask,
|
||||
"use_cache": tf.convert_to_tensor([False]),
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
@ -298,14 +294,6 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFT5Model.from_pretrained("t5-small")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
|
||||
class TFT5EncoderOnlyModelTester:
|
||||
def __init__(
|
||||
|
@ -411,6 +399,7 @@ class TFT5EncoderOnlyModelTester:
|
|||
|
||||
|
||||
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue