Add test for proper TF input signatures (#24320)

* Add test for proper input signatures

* No more signature pruning

* Test the dummy inputs are valid too

* fine-tine -> fine-tune

* Fix indent in test_dataset_conversion
Matt, 2023-06-16 17:03:13 +01:00 (committed by GitHub)
parent bdfd57d1d1
commit 9138995025
6 changed files with 21 additions and 19 deletions

src/transformers/modeling_tf_utils.py

@@ -1122,8 +1122,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             `Dict[str, tf.Tensor]`: The dummy inputs.
         """
         dummies = {}
-        sig = self._prune_signature(self.input_signature)
-        for key, spec in sig.items():
+        for key, spec in self.input_signature.items():
             # 2 is the most correct arbitrary size. I will not be taking questions
             dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
             if spec.shape[0] is None:
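
For readers following along, here is a minimal sketch of the pattern the new code follows when materializing dummy inputs from an input signature. The signature dict below is illustrative, not taken from any specific model:

import tensorflow as tf

# Illustrative signature; real models define their own input_signature property
input_signature = {
    "input_ids": tf.TensorSpec([None, None], tf.int32, name="input_ids"),
    "attention_mask": tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
}

dummies = {}
for key, spec in input_signature.items():
    # Unknown (None) dimensions are pinned to an arbitrary small size, here 2
    dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
    dummies[key] = tf.ones(dummy_shape, dtype=spec.dtype)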
@@ -1159,7 +1158,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         self.built = True
         # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
         # Setting it in build() allows users to override the shape when loading a non-pretrained model from config
-        self._set_save_spec(self._prune_signature(self.input_signature))
+        self._set_save_spec(self.input_signature)
         self(self.dummy_inputs, training=False)

     def __init__(self, config, *inputs, **kwargs):
@@ -1300,11 +1299,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             raise NotImplementedError("Audio models need a manually defined input_signature")
         return sig

-    def _prune_signature(self, signature):
-        """Keeps only the keys of a given input signature that are valid for this model."""
-        model_inputs = list(inspect.signature(self.call).parameters)
-        return {key: val for key, val in signature.items() if key in model_inputs}
-
     def serving_output(self, output):
         """
         Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
@@ -2423,14 +2417,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
             self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
         if signatures is None:
-            sig = self._prune_signature(self.input_signature)
-            serving_default = self.serving.get_concrete_function(sig)
-            if any(spec.dtype == tf.int32 for spec in sig.values()):
+            serving_default = self.serving.get_concrete_function(self.input_signature)
+            if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
                 int64_spec = {
                     key: tf.TensorSpec(
                         shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
                     )
-                    for key, spec in sig.items()
+                    for key, spec in self.input_signature.items()
                 }
                 int64_serving = self.serving.get_concrete_function(int64_spec)
                 signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
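
With pruning gone, the exported SavedModel derives both serving signatures directly from input_signature. A hedged usage sketch of what these signatures enable after export; the checkpoint name is illustrative, and save_pretrained(..., saved_model=True) writes the SavedModel under saved_model/1:

import tensorflow as tf
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
model.save_pretrained("exported", saved_model=True)

loaded = tf.saved_model.load("exported/saved_model/1")
serve_int32 = loaded.signatures["serving_default"]  # expects int32 tensors
serve_int64 = loaded.signatures["int64_serving"]    # same graph, int64 inputs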

src/transformers/models/hubert/modeling_tf_hubert.py

@@ -1168,7 +1168,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         logger.warning(
             f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
-            "to train/fine-tine this model, you need a GPU or a TPU"
+            "to train/fine-tune this model, you need a GPU or a TPU"
         )

src/transformers/models/mobilevit/modeling_tf_mobilevit.py

@@ -98,7 +98,7 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         logger.warning(
             f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
-            "to train/fine-tine this model, you need a GPU or a TPU"
+            "to train/fine-tune this model, you need a GPU or a TPU"
         )

         padding = int((kernel_size - 1) / 2) * dilation

src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

@@ -1202,7 +1202,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         logger.warning(
             f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
-            "to train/fine-tine this model, you need a GPU or a TPU"
+            "to train/fine-tune this model, you need a GPU or a TPU"
         )

     def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):

tests/test_modeling_tf_common.py

@@ -1065,6 +1065,16 @@ class TFModelTesterMixin:
             output_for_kw_input = model(**inputs_np)
             self.assert_outputs_same(output_for_dict_input, output_for_kw_input)

+    def test_valid_input_signature_and_dummies(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            call_args = inspect.signature(model.call).parameters
+            for key in model.input_signature:
+                self.assertIn(key, call_args)
+            for key in model.dummy_inputs:
+                self.assertIn(key, call_args)
+
     def test_resize_token_embeddings(self):
         # TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
         # tf.keras.layers.Embedding
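
The new test pins down the invariant that replaces runtime pruning: every key in input_signature and dummy_inputs must be an argument of the model's call method. A standalone illustration of that invariant, assuming TFBertModel as a representative model:

import inspect
from transformers import BertConfig, TFBertModel

model = TFBertModel(BertConfig())  # randomly initialized, no download needed
call_params = inspect.signature(model.call).parameters
assert all(key in call_params for key in model.input_signature)
assert all(key in call_params for key in model.dummy_inputs)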

tests/test_modeling_tf_core.py

@@ -217,18 +217,17 @@ class TFCoreModelTesterMixin:
         for model_class in self.all_model_classes:
             class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config)
-            class_sig = model._prune_signature(model.input_signature)
             num_out = len(model(class_inputs_dict))

             for key in list(class_inputs_dict.keys()):
                 # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
-                if key not in class_sig:
+                if key not in model.input_signature:
                     del class_inputs_dict[key]
                 # Check it's a tensor, in case the inputs dict has some bools in it too
                 elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
                     class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)

-            if set(class_inputs_dict.keys()) != set(class_sig.keys()):
+            if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
                 continue  # Some models have inputs that the preparation functions don't create, we skip those

             with tempfile.TemporaryDirectory() as tmpdirname:
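
The hunk ends where the core test enters its temporary directory; the rest of the test body is elided above. As a hedged sketch of the round-trip such a test typically performs (not the test's literal code):

import tempfile

import tensorflow as tf

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname, saved_model=True)
    loaded = tf.saved_model.load(f"{tmpdirname}/saved_model/1")
    serving = loaded.signatures["serving_default"]
    # Keys were pruned above to match the signature, so the kwargs line up exactly
    outputs = serving(**class_inputs_dict)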