Update serving code to enable `saved_model=True` (#18153)

* Add serving_output and serving methods to some vision models

* Add serving outputs for DeiT

* Don't convert hidden states - differing shapes

* Make saveable

* Fix up

* Make swin saveable

* Add in tests

* Fix funnel tests (can't convert to tensor)

* Fix numpy call

* Tidy up a bit

* Add in hidden states - resnet

* Remove numpy

* Fix failing tests - tensor shape and skipping tests

* Remove duplicated function

* PR comments - formatting and var names

* PR comments
Add suggestions made by Joao Gante:
* Use tf.shape instead of shape_list
* Use @tooslow decorator on tests
* Simplify some of the logic

* PR comments
Address Yih-Dar Sheih comments - making tensor names consistent and make types float

* Types consistent with docs; disable test on swin (slow)

* CI trigger

* Change input_features to float32

* Add serving_output for segformer

* Fixup

Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
This commit is contained in:
amyeroberts 2022-07-22 18:05:38 +01:00 committed by GitHub
parent 07505358ba
commit 8e8384663d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 471 additions and 238 deletions

View File

@ -383,7 +383,8 @@ class TFConvNextPreTrainedModel(TFPreTrainedModel):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)
CONVNEXT_START_DOCSTRING = r"""
@ -492,6 +493,14 @@ class TFConvNextModel(TFConvNextPreTrainedModel):
hidden_states=outputs.hidden_states,
)
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
)
@add_start_docstrings(
"""
@ -584,3 +593,7 @@ class TFConvNextForImageClassification(TFConvNextPreTrainedModel, TFSequenceClas
logits=logits,
hidden_states=outputs.hidden_states,
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=output.hidden_states)

View File

@ -801,8 +801,8 @@ class TFData2VecVisionPreTrainedModel(TFPreTrainedModel):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)
DATA2VEC_VISION_START_DOCSTRING = r"""
@ -910,6 +910,17 @@ class TFData2VecVisionModel(TFData2VecVisionPreTrainedModel):
return outputs
def serving_output(self, output: TFData2VecVisionModelOutputWithPooling) -> TFData2VecVisionModelOutputWithPooling:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFData2VecVisionModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hidden_states,
attentions=attentions,
)
@add_start_docstrings(
"""
@ -983,6 +994,12 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF
attentions=outputs.attentions,
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
class TFData2VecVisionConvModule(tf.keras.layers.Layer):
"""
@ -1443,3 +1460,9 @@ class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSemanticSegmenterOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)

View File

@ -193,7 +193,7 @@ class TFDeiTPatchEmbeddings(tf.keras.layers.Layer):
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if tf.executing_eagerly() and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
@ -680,14 +680,14 @@ class TFDeiTModel(TFDeiTPreTrainedModel):
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
hidden_states=hidden_states,
attentions=attentions,
)
@ -864,6 +864,12 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
attentions=outputs.attentions,
)
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
@add_start_docstrings(
"""
@ -961,6 +967,12 @@ class TFDeiTForImageClassification(TFDeiTPreTrainedModel, TFSequenceClassificati
attentions=outputs.attentions,
)
def serving_output(self, output: TFImageClassifierOutput) -> TFImageClassifierOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFImageClassifierOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
@add_start_docstrings(
"""
@ -1041,3 +1053,17 @@ class TFDeiTForImageClassificationWithTeacher(TFDeiTPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(
self, output: TFDeiTForImageClassificationWithTeacherOutput
) -> TFDeiTForImageClassificationWithTeacherOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFDeiTForImageClassificationWithTeacherOutput(
logits=output.logits,
cls_logits=output.cls_logits,
distillation_logits=output.distillation_logits,
hidden_states=hidden_states,
attentions=attentions,
)

View File

@ -1127,12 +1127,14 @@ class TFFunnelBaseModel(TFFunnelPreTrainedModel):
training=training,
)
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
@add_start_docstrings(
@ -1175,12 +1177,14 @@ class TFFunnelModel(TFFunnelPreTrainedModel):
training=training,
)
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
@add_start_docstrings(
@ -1249,10 +1253,11 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
)
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFFunnelForPreTrainingOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFFunnelForPreTrainingOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
@add_start_docstrings("""Funnel Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
@ -1322,12 +1327,10 @@ class TFFunnelForMaskedLM(TFFunnelPreTrainedModel, TFMaskedLanguageModelingLoss)
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFMaskedLMOutput(logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions)
@add_start_docstrings(
@ -1398,12 +1401,12 @@ class TFFunnelForSequenceClassification(TFFunnelPreTrainedModel, TFSequenceClass
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFSequenceClassifierOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
@add_start_docstrings(
@ -1503,9 +1506,9 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
@ -1514,12 +1517,12 @@ class TFFunnelForMultipleChoice(TFFunnelPreTrainedModel, TFMultipleChoiceLoss):
return self.serving_output(output=output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFMultipleChoiceModelOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
@add_start_docstrings(
@ -1592,12 +1595,12 @@ class TFFunnelForTokenClassification(TFFunnelPreTrainedModel, TFTokenClassificat
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFTokenClassifierOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
@add_start_docstrings(
@ -1683,11 +1686,12 @@ class TFFunnelForQuestionAnswering(TFFunnelPreTrainedModel, TFQuestionAnsweringL
attentions=outputs.attentions,
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of
# different dimensions
return TFQuestionAnsweringModelOutput(
start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns
start_logits=output.start_logits,
end_logits=output.end_logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
)

View File

@ -227,12 +227,13 @@ def _compute_mask_indices(
f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
num_masked_spans = max(num_masked_spans, min_masks)
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
num_masked_spans = tf.squeeze(num_masked_spans)
# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
@ -256,7 +257,7 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)
return spec_aug_mask
@ -1319,7 +1320,15 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
"to train/fine-tine this model, you need a GPU or a TPU"
)
@tf.function
@tf.function(
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)
@ -1511,10 +1520,11 @@ class TFHubertModel(TFHubertPreTrainedModel):
return outputs
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state, hidden_states=hidden_states, attentions=attentions
)
@add_start_docstrings(
@ -1685,6 +1695,6 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
)
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)

View File

@ -371,7 +371,8 @@ class TFRegNetPreTrainedModel(TFPreTrainedModel):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)
REGNET_START_DOCSTRING = r"""
@ -444,6 +445,16 @@ class TFRegNetModel(TFRegNetPreTrainedModel):
hidden_states=outputs.hidden_states,
)
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndNoAttention
) -> TFBaseModelOutputWithPoolingAndNoAttention:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
)
@add_start_docstrings(
"""
@ -506,3 +517,7 @@ class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassifi
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=output.hidden_states)

View File

@ -263,10 +263,7 @@ class TFResNetEncoder(tf.keras.layers.Layer):
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return TFBaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
class TFResNetPreTrainedModel(TFPreTrainedModel):
@ -288,6 +285,17 @@ class TFResNetPreTrainedModel(TFPreTrainedModel):
VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
RESNET_START_DOCSTRING = r"""
This model is a TensorFlow
@ -413,6 +421,16 @@ class TFResNetModel(TFResNetPreTrainedModel):
)
return resnet_outputs
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndNoAttention
) -> TFBaseModelOutputWithPoolingAndNoAttention:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
)
@add_start_docstrings(
"""
@ -477,3 +495,7 @@ class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassifi
return (loss,) + output if loss is not None else output
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states)

View File

@ -544,7 +544,9 @@ class TFSegformerPreTrainedModel(TFPreTrainedModel):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)
SEGFORMER_START_DOCSTRING = r"""
@ -628,6 +630,14 @@ class TFSegformerModel(TFSegformerPreTrainedModel):
)
return outputs
def serving_output(self, output: TFBaseModelOutput) -> TFBaseModelOutput:
# hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFBaseModelOutput(
last_hidden_state=output.last_hidden_state,
hidden_states=output.hidden_states,
attentions=output.attentions,
)
@add_start_docstrings(
"""
@ -692,6 +702,12 @@ class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceCl
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
# hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSequenceClassifierOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)
class TFSegformerMLP(tf.keras.layers.Layer):
"""
@ -876,3 +892,9 @@ class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
def serving_output(self, output: TFSemanticSegmenterOutput) -> TFSemanticSegmenterOutput:
# hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSemanticSegmenterOutput(
logits=output.logits, hidden_states=output.hidden_states, attentions=output.attentions
)

View File

@ -143,7 +143,8 @@ class TFConv1dSubsampler(tf.keras.layers.Layer):
]
def call(self, input_features: tf.Tensor) -> tf.Tensor:
hidden_states = tf.identity(input_features) # TF Conv1D assumes Batch x Time x Channels, same as the input
# TF Conv1D assumes Batch x Time x Channels, same as the input
hidden_states = tf.cast(input_features, tf.float32)
for i, conv in enumerate(self.conv_layers):
# equivalent to `padding=k // 2` on PT's `nn.Conv1d`
pad_len = self.kernel_sizes[i] // 2
@ -187,23 +188,20 @@ class TFSpeech2TextSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
# zero pad
emb = tf.concat([emb, tf.zeros(num_embeddings, 1)], axis=1)
if padding_idx is not None:
emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, emb.shape[1])), emb[padding_idx + 1 :, :]], axis=0)
emb = tf.concat([emb[:padding_idx, :], tf.zeros((1, tf.shape(emb)[1])), emb[padding_idx + 1 :, :]], axis=0)
return emb
def _resize_embeddings(self):
"""Recreates (and effectivelly resizes) the sinusoidal embeddings"""
self.embeddings = self.add_weight(
name="weights", # name also used in PT
shape=self.embedding_weights.shape,
)
self.embeddings.assign(self.embedding_weights)
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self._resize_embeddings()
self.embeddings = self.add_weight(
name="weights", # name also used in PT
shape=tf.shape(self.embedding_weights),
trainable=False,
)
self.embeddings.assign(self.embedding_weights)
super().build(input_shape)
def call(self, input_ids: tf.Tensor, past_key_values_length: int = 0) -> tf.Tensor:
@ -215,7 +213,7 @@ class TFSpeech2TextSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
max_pos = self.padding_idx + 1 + seq_len
if max_pos > shape_list(self.embeddings)[0]:
self.embedding_weights = self._get_embedding(max_pos + self.offset, self.embedding_dim, self.padding_idx)
self._resize_embeddings()
self.embeddings.assign(self.embedding_weights)
return tf.reshape(tf.gather(self.embeddings, tf.reshape(position_ids, (-1,)), axis=0), (bsz, seq_len, -1))
@staticmethod
@ -608,7 +606,7 @@ class TFSpeech2TextPreTrainedModel(TFPreTrainedModel):
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"input_features": tf.TensorSpec((None, None, None), tf.float32, name="input_features"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
"decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"),
@ -791,7 +789,6 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
),
axis=-1,
)
attention_mask = tf.scatter_nd(indices=indices, updates=tf.ones(bsz), shape=[bsz, feature_vector_length])
attention_mask = tf.cast(tf.reverse(tf.math.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]), tf.int64)
return attention_mask
@ -845,10 +842,10 @@ class TFSpeech2TextEncoder(tf.keras.layers.Layer):
# subsample attention mask if necessary
if attention_mask is not None:
attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
attention_mask = self._get_feature_vector_attention_mask(tf.shape(inputs_embeds)[1], attention_mask)
padding_mask = tf.cast(tf.math.not_equal(attention_mask, 1), tf.int64)
else:
padding_mask = tf.zeros(inputs_embeds.shape[:-1], dtype=tf.int64)
padding_mask = tf.zeros(tf.shape(inputs_embeds)[:-1], dtype=tf.int64)
embed_pos = self.embed_positions(padding_mask)
@ -942,22 +939,6 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@unpack_inputs
def call(
self,
@ -1053,9 +1034,16 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
else:
inputs_embeds = inputs_embeds
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
if attention_mask is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
@ -1100,7 +1088,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=attention_mask,
attention_mask=combined_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=head_mask[idx] if head_mask is not None else None,
@ -1203,7 +1191,7 @@ class TFSpeech2TextMainLayer(tf.keras.layers.Layer):
# downsample encoder attention mask
if attention_mask is not None:
encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
encoder_outputs[0].shape[1], attention_mask
tf.shape(encoder_outputs[0])[1], attention_mask
)
else:
encoder_attention_mask = None
@ -1465,8 +1453,8 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
return TFSeq2SeqModelOutput(
last_hidden_state=output.last_hidden_state,
return TFSeq2SeqLMOutput(
logits=output.logits,
past_key_values=pkv,
decoder_hidden_states=dec_hs,
decoder_attentions=dec_attns,

View File

@ -227,7 +227,7 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
Merges windows to produce higher resolution features.
"""
x = shape_list(windows)[0]
y = tf.cast(height * width / window_size / window_size, tf.int32)
y = tf.cast(height * width / (window_size * window_size), tf.int32)
batch_size = int(x / y)
windows = tf.reshape(
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
@ -541,7 +541,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
attention_scores = attention_scores + tf.expand_dims(relative_position_bias, 0)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in SwinModel forward() function)
# Apply the attention mask is (precomputed for all layers in SwinModel call() function)
mask_shape = shape_list(attention_mask)[0]
attention_scores = tf.reshape(
attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
@ -647,10 +647,10 @@ class TFSwinLayer(tf.keras.layers.Layer):
) -> None:
super().__init__(**kwargs)
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
min_res = tf.reduce_min(input_resolution)
self.window_size = min_res if min_res <= config.window_size else config.window_size
self.shift_size = 0 if min_res <= self.window_size else shift_size
self.input_resolution = input_resolution
self.set_shift_and_window_size(input_resolution)
self.layernorm_before = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_before"
@ -659,7 +659,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
self.drop_path = (
TFSwinDropPath(config.drop_path_rate, name="drop_path")
if config.drop_path_rate > 0.0
else tf.identity(name="drop_path")
else tf.keras.layers.Activation("linear", name="drop_path")
)
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
@ -667,56 +667,38 @@ class TFSwinLayer(tf.keras.layers.Layer):
self.intermediate = TFSwinIntermediate(config, dim, name="intermediate")
self.swin_output = TFSwinOutput(config, dim, name="output")
def set_shift_and_window_size(self, input_resolution: Tuple[int, int]) -> None:
if min(input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(input_resolution)
def get_attn_mask(self, height: int, width: int) -> Optional[tf.Tensor]:
if self.shift_size > 0:
# calculate attention mask for SW-MSA
img_mask = tf.zeros((height, width))
height_slices = (
(0, -self.window_size),
(-self.window_size, -self.shift_size),
(-self.shift_size, -1),
)
width_slices = (
(0, -self.window_size),
(-self.window_size, -self.shift_size),
(-self.shift_size, -1),
)
def get_attn_mask(self, height: int, width: int, window_size: int, shift_size: int) -> Optional[tf.Tensor]:
img_mask = tf.zeros((height, width))
height_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
width_slices = ((0, -window_size), (-window_size, -shift_size), (-shift_size, -1))
# calculate attention mask for SW-MSA
if shift_size > 0:
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
indices = [
[i, j]
for i in range(height_slice[0] % height, height_slice[1] % height + 1)
for j in range(width_slice[0] % width, width_slice[1] % width + 1)
]
if indices:
height_inds = tf.range(height_slice[0] % height, height_slice[1] % height + 1)
width_inds = tf.range(width_slice[0] % width, width_slice[1] % width + 1)
indices = tf.reshape(tf.stack(tf.meshgrid(height_inds, width_inds), axis=-1), (-1, 2))
if len(indices) >= 1:
updates = tf.ones((len(indices),), dtype=img_mask.dtype) * count
img_mask = tf.tensor_scatter_nd_update(img_mask, indices, updates)
count += 1
img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0)
img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0)
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
else:
attn_mask = None
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
return attn_mask
def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
pad_values = tf.constant([[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]])
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
hidden_states = tf.pad(hidden_states, pad_values)
pad_values = tf.reshape(pad_values, (-1,))
return hidden_states, pad_values
@ -729,7 +711,11 @@ class TFSwinLayer(tf.keras.layers.Layer):
output_attentions: bool = False,
training: bool = False,
) -> tf.Tensor:
self.set_shift_and_window_size(input_dimensions)
# if window size is larger than input resolution, we don't partition windows
min_res = tf.reduce_min(input_dimensions)
shift_size = 0 if min_res <= self.window_size else self.shift_size
window_size = min_res if min_res <= self.window_size else self.window_size
height, width = input_dimensions
batch_size, _, channels = shape_list(hidden_states)
shortcut = hidden_states
@ -741,15 +727,17 @@ class TFSwinLayer(tf.keras.layers.Layer):
_, height_pad, width_pad, _ = shape_list(hidden_states)
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
if shift_size > 0:
shifted_hidden_states = tf.roll(hidden_states, shift=(-shift_size, -shift_size), axis=(1, 2))
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = tf.reshape(hidden_states_windows, (-1, self.window_size * self.window_size, channels))
attn_mask = self.get_attn_mask(height_pad, width_pad)
hidden_states_windows = window_partition(shifted_hidden_states, window_size)
hidden_states_windows = tf.reshape(hidden_states_windows, (-1, window_size * window_size, channels))
attn_mask = self.get_attn_mask(
height=height_pad, width=width_pad, window_size=window_size, shift_size=shift_size
)
attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions, training=training
@ -757,12 +745,12 @@ class TFSwinLayer(tf.keras.layers.Layer):
attention_output = attention_outputs[0]
attention_windows = tf.reshape(attention_output, (-1, self.window_size, self.window_size, channels))
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
attention_windows = tf.reshape(attention_output, (-1, window_size, window_size, channels))
shifted_windows = window_reverse(attention_windows, window_size, height_pad, width_pad)
# reverse cyclic shift
if self.shift_size > 0:
attention_windows = tf.roll(shifted_windows, shift=(self.shift_size, self.shift_size), axis=(1, 2))
if shift_size > 0:
attention_windows = tf.roll(shifted_windows, shift=(shift_size, shift_size), axis=(1, 2))
else:
attention_windows = shifted_windows
@ -961,6 +949,17 @@ class TFSwinPreTrainedModel(TFPreTrainedModel):
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
SWIN_START_DOCSTRING = r"""
This model is a Tensorflow
@ -1223,6 +1222,16 @@ class TFSwinModel(TFSwinPreTrainedModel):
return swin_outputs
def serving_output(self, output: TFSwinModelOutput) -> TFSwinModelOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinModelOutput(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states,
)
class TFSwinPixelShuffle(tf.keras.layers.Layer):
"""TF layer implementation of torch.nn.PixelShuffle"""
@ -1377,6 +1386,15 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinMaskedImageModelingOutput(
logits=output.logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states,
)
@add_start_docstrings(
"""
@ -1396,7 +1414,7 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
self.classifier = (
tf.keras.layers.Dense(config.num_labels, name="classifier")
if config.num_labels > 0
else tf.identity(name="classifier")
else tf.keras.layers.Activation("linear", name="classifier")
)
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@ -1452,3 +1470,12 @@ class TFSwinForImageClassification(TFSwinPreTrainedModel, TFSequenceClassificati
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
def serving_output(self, output: TFSwinImageClassifierOutput) -> TFSwinImageClassifierOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinImageClassifierOutput(
logits=output.logits,
hidden_states=output.hidden_states,
attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states,
)

View File

@ -862,6 +862,19 @@ class TFTapasPreTrainedModel(TFPreTrainedModel):
config_class = TapasConfig
base_model_prefix = "tapas"
@tf.function(
input_signature=[
{
"input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
"attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(inputs)
return self.serving_output(output)
TAPAS_START_DOCSTRING = r"""
@ -1021,14 +1034,14 @@ class TFTapasModel(TFTapasPreTrainedModel):
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
hidden_states=hidden_states,
attentions=attentions,
)
@ -1128,10 +1141,10 @@ class TFTapasForMaskedLM(TFTapasPreTrainedModel, TFMaskedLanguageModelingLoss):
)
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
class TFTapasComputeTokenLogits(tf.keras.layers.Layer):
@ -1557,11 +1570,14 @@ class TFTapasForQuestionAnswering(TFTapasPreTrainedModel):
)
def serving_output(self, output: TFTableQuestionAnsweringOutput) -> TFTableQuestionAnsweringOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFTableQuestionAnsweringOutput(
logits=output.logits, logits_aggregation=output.logits_aggregation, hidden_states=hs, attentions=attns
logits=output.logits,
logits_aggregation=output.logits_aggregation,
hidden_states=hidden_states,
attentions=attentions,
)
@ -1667,10 +1683,10 @@ class TFTapasForSequenceClassification(TFTapasPreTrainedModel, TFSequenceClassif
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)
""" TAPAS utilities."""

View File

@ -722,8 +722,8 @@ class TFViTMAEPreTrainedModel(TFPreTrainedModel):
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
output = self.call(inputs)
return self.serving_output(output)
VIT_MAE_START_DOCSTRING = r"""
@ -842,6 +842,18 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
return outputs
def serving_output(self, output: TFViTMAEModelOutput) -> TFViTMAEModelOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFViTMAEModelOutput(
last_hidden_state=output.last_hidden_state,
mask=output.mask,
ids_restore=output.ids_restore,
hidden_states=hidden_states,
attentions=attentions,
)
class TFViTMAEDecoder(tf.keras.layers.Layer):
def __init__(self, config, num_patches, **kwargs):
@ -1143,3 +1155,15 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFViTMAEForPreTrainingOutput) -> TFViTMAEForPreTrainingOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFViTMAEForPreTrainingOutput(
logits=output.logits,
mask=output.mask,
ids_restore=output.ids_restore,
hidden_states=hidden_states,
attentions=attentions,
)

View File

@ -268,12 +268,13 @@ def _compute_mask_indices(
f" `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
num_masked_spans = max(num_masked_spans, min_masks)
num_masked_spans = mask_prob * sequence_length / mask_length + tf.random.uniform((1,))
num_masked_spans = tf.maximum(num_masked_spans, min_masks)
num_masked_spans = tf.cast(num_masked_spans, tf.int32)
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
num_masked_spans = tf.squeeze(num_masked_spans)
# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
@ -297,7 +298,7 @@ def _compute_mask_indices(
# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
)
return spec_aug_mask
@ -1352,7 +1353,15 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
"to train/fine-tine this model, you need a GPU or a TPU"
)
@tf.function
@tf.function(
input_signature=[
{
"input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"),
"attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
"token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
}
]
)
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)
@ -1544,14 +1553,14 @@ class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
return outputs
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFWav2Vec2BaseModelOutput(
last_hidden_state=output.last_hidden_state,
extract_features=output.extract_features,
hidden_states=hs,
attentions=attns,
hidden_states=hidden_states,
attentions=attentions,
)
@ -1726,6 +1735,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
)
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions)

View File

@ -18,7 +18,7 @@ import unittest
import numpy as np
from transformers import BartConfig, BartTokenizer, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -293,8 +293,8 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

View File

@ -17,7 +17,7 @@
import unittest
from transformers import BlenderbotConfig, BlenderbotTokenizer, is_tf_available
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -213,8 +213,8 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):

View File

@ -17,7 +17,7 @@
import unittest
from transformers import BlenderbotSmallConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers.testing_utils import require_tf, require_tokenizers, slow
from transformers.testing_utils import require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -278,8 +278,8 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

View File

@ -606,11 +606,21 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFCLIPModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation(self):
pass
@unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
@slow
def test_saved_model_creation_extended(self):
pass
@unittest.skip(reason="`saved_model` doesn't work with nested outputs so no preparation happens.")
@slow
def test_prepare_serving_output(self):
pass
# We will verify our results on an image of cute cats
def prepare_img():

View File

@ -17,7 +17,7 @@
import unittest
from transformers import FunnelConfig, is_tf_available
from transformers.testing_utils import require_tf
from transformers.testing_utils import require_tf, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@ -371,8 +371,8 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_compile_tf_model(self):
@ -407,6 +407,6 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

View File

@ -17,7 +17,7 @@
import unittest
from transformers import LEDConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@ -365,8 +365,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO JP: Make LED XLA compliant
pass
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_generate_with_headmasking(self):

View File

@ -17,7 +17,7 @@
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@ -326,8 +326,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_xla_mode(self):

View File

@ -20,7 +20,7 @@ import unittest
import numpy as np
from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@ -600,8 +600,8 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
@slow

View File

@ -19,7 +19,7 @@ import unittest
import warnings
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -246,8 +246,8 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):

View File

@ -17,7 +17,7 @@ import tempfile
import unittest
from transformers import AutoTokenizer, MBartConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -281,8 +281,8 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

View File

@ -17,7 +17,7 @@
import unittest
from transformers import MobileBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
@ -306,8 +306,8 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
@slow

View File

@ -18,7 +18,7 @@ import unittest
import numpy as np
from transformers import OPTConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, slow
from transformers.testing_utils import require_sentencepiece, require_tf, slow, tooslow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@ -227,8 +227,8 @@ class TFOPTModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

View File

@ -17,7 +17,7 @@ import tempfile
import unittest
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -244,8 +244,8 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
def test_resize_token_embeddings(self):

View File

@ -21,7 +21,7 @@ import unittest
import numpy as np
from transformers import SwinConfig
from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple
from transformers.testing_utils import require_tf, require_vision, slow, to_2tuple, tooslow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -225,6 +225,10 @@ class TFSwinModelTest(TFModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@tooslow
def test_saved_model_creation(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -16,7 +16,7 @@
import unittest
from transformers import T5Config, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow, tooslow
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@ -305,8 +305,8 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
@tooslow
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass
@slow

View File

@ -205,6 +205,47 @@ class TFModelTesterMixin:
self.assert_outputs_same(after_outputs, outputs)
@slow
def test_saved_model_creation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = False
config.output_attentions = False
if hasattr(config, "use_cache"):
config.use_cache = False
model_class = self.all_model_classes[0]
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
model(class_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir))
def test_prepare_serving_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = self.has_attentions
for model_class in self.all_model_classes:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(inputs)
serving_outputs = model.serving_output(outputs)
for k, v in serving_outputs.items():
# Check that we have one of three possible outputs: None, tuple of tensors or a tensor
if isinstance(v, tuple):
self.assertTrue(all(isinstance(elem, tf.Tensor) for elem in v))
elif v is not None:
self.assertIsInstance(v, tf.Tensor)
else:
self.assertIsNone(v)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -201,27 +201,6 @@ class TFCoreModelTesterMixin:
val_loss = history.history["val_loss"][0]
self.assertTrue(not isnan(val_loss))
@slow
def test_saved_model_creation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = False
config.output_attentions = False
if hasattr(config, "use_cache"):
config.use_cache = False
model_class = self.all_model_classes[0]
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
model(class_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir))
@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()