Add test for proper TF input signatures (#24320)

* Add test for proper input signatures

* No more signature pruning

* Test the dummy inputs are valid too

* fine-tine -> fine-tune

* Fix indent in test_dataset_conversion
This commit is contained in:
Matt 2023-06-16 17:03:13 +01:00 committed by GitHub
parent bdfd57d1d1
commit 9138995025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 21 additions and 19 deletions

View File

@ -1122,8 +1122,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
dummies = {}
sig = self._prune_signature(self.input_signature)
for key, spec in sig.items():
for key, spec in self.input_signature.items():
# 2 is the most correct arbitrary size. I will not be taking questions
dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
if spec.shape[0] is None:
@ -1159,7 +1158,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.built = True
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
self._set_save_spec(self._prune_signature(self.input_signature))
self._set_save_spec(self.input_signature)
self(self.dummy_inputs, training=False)
def __init__(self, config, *inputs, **kwargs):
@ -1300,11 +1299,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise NotImplementedError("Audio models need a manually defined input_signature")
return sig
def _prune_signature(self, signature):
    """Keeps only the keys of a given input signature that are valid for this model.

    Args:
        signature: Mapping from input name to its spec.

    Returns:
        A new dict holding only the entries of `signature` whose key is a parameter name of
        `self.call`. The mapping passed in is not modified.
    """
    # Resolve the parameter names once; a set keeps each membership test O(1).
    # The result still follows the iteration order of `signature`.
    call_params = set(inspect.signature(self.call).parameters)
    return {key: spec for key, spec in signature.items() if key in call_params}
def serving_output(self, output):
"""
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
@ -2423,14 +2417,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
if signatures is None:
sig = self._prune_signature(self.input_signature)
serving_default = self.serving.get_concrete_function(sig)
if any(spec.dtype == tf.int32 for spec in sig.values()):
serving_default = self.serving.get_concrete_function(self.input_signature)
if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
int64_spec = {
key: tf.TensorSpec(
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
)
for key, spec in sig.items()
for key, spec in self.input_signature.items()
}
int64_serving = self.serving.get_concrete_function(int64_spec)
signatures = {"serving_default": serving_default, "int64_serving": int64_serving}

View File

@ -1168,7 +1168,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
logger.warning(
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
"to train/fine-tine this model, you need a GPU or a TPU"
"to train/fine-tune this model, you need a GPU or a TPU"
)

View File

@ -98,7 +98,7 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
logger.warning(
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
"to train/fine-tine this model, you need a GPU or a TPU"
"to train/fine-tune this model, you need a GPU or a TPU"
)
padding = int((kernel_size - 1) / 2) * dilation

View File

@ -1202,7 +1202,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
logger.warning(
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
"to train/fine-tine this model, you need a GPU or a TPU"
"to train/fine-tune this model, you need a GPU or a TPU"
)
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):

View File

@ -1065,6 +1065,16 @@ class TFModelTesterMixin:
output_for_kw_input = model(**inputs_np)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_valid_input_signature_and_dummies(self):
    """Check that every key of `input_signature` and `dummy_inputs` is an argument of `call`.

    Each class in `self.all_model_classes` is instantiated from the common test config, and the
    keys of both mappings are compared against the parameter names of the model's `call` method.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        model = model_class(config)
        call_args = inspect.signature(model.call).parameters
        for key in model.input_signature:
            self.assertIn(
                key,
                call_args,
                msg=f"{model_class.__name__}.input_signature has key '{key}' that call() does not accept",
            )
        for key in model.dummy_inputs:
            self.assertIn(
                key,
                call_args,
                msg=f"{model_class.__name__}.dummy_inputs has key '{key}' that call() does not accept",
            )
def test_resize_token_embeddings(self):
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
# tf.keras.layers.Embedding
@ -1700,7 +1710,7 @@ class TFModelTesterMixin:
for tensor in test_batch.values():
self.assertTrue(isinstance(tensor, tf.Tensor))
self.assertEqual(len(tensor), len(input_dataset)) # Assert we didn't lose any data
model(test_batch, training=False)
model(test_batch, training=False)
if "labels" in inspect.signature(model_class.call).parameters.keys():
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

View File

@ -217,18 +217,17 @@ class TFCoreModelTesterMixin:
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
class_sig = model._prune_signature(model.input_signature)
num_out = len(model(class_inputs_dict))
for key in list(class_inputs_dict.keys()):
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
if key not in class_sig:
if key not in model.input_signature:
del class_inputs_dict[key]
# Check it's a tensor, in case the inputs dict has some bools in it too
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
continue # Some models have inputs that the preparation functions don't create, we skip those
with tempfile.TemporaryDirectory() as tmpdirname: