Add test for proper TF input signatures (#24320)
* Add test for proper input signatures * No more signature pruning * Test the dummy inputs are valid too * fine-tine -> fine-tune * Fix indent in test_dataset_conversion
This commit is contained in:
parent
bdfd57d1d1
commit
9138995025
|
@ -1122,8 +1122,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
`Dict[str, tf.Tensor]`: The dummy inputs.
|
`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
"""
|
"""
|
||||||
dummies = {}
|
dummies = {}
|
||||||
sig = self._prune_signature(self.input_signature)
|
for key, spec in self.input_signature.items():
|
||||||
for key, spec in sig.items():
|
|
||||||
# 2 is the most correct arbitrary size. I will not be taking questions
|
# 2 is the most correct arbitrary size. I will not be taking questions
|
||||||
dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
|
dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
|
||||||
if spec.shape[0] is None:
|
if spec.shape[0] is None:
|
||||||
|
@ -1159,7 +1158,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
self.built = True
|
self.built = True
|
||||||
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
|
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
|
||||||
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
|
# Setting it in build() allows users to override the shape when loading a non-pretrained model from config
|
||||||
self._set_save_spec(self._prune_signature(self.input_signature))
|
self._set_save_spec(self.input_signature)
|
||||||
self(self.dummy_inputs, training=False)
|
self(self.dummy_inputs, training=False)
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
|
@ -1300,11 +1299,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
raise NotImplementedError("Audio models need a manually defined input_signature")
|
raise NotImplementedError("Audio models need a manually defined input_signature")
|
||||||
return sig
|
return sig
|
||||||
|
|
||||||
def _prune_signature(self, signature):
|
|
||||||
"""Keeps only the keys of a given input signature that are valid for this model."""
|
|
||||||
model_inputs = list(inspect.signature(self.call).parameters)
|
|
||||||
return {key: val for key, val in signature.items() if key in model_inputs}
|
|
||||||
|
|
||||||
def serving_output(self, output):
|
def serving_output(self, output):
|
||||||
"""
|
"""
|
||||||
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
|
Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
|
||||||
|
@ -2423,14 +2417,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||||
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
|
if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
|
||||||
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
|
self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
|
||||||
if signatures is None:
|
if signatures is None:
|
||||||
sig = self._prune_signature(self.input_signature)
|
serving_default = self.serving.get_concrete_function(self.input_signature)
|
||||||
serving_default = self.serving.get_concrete_function(sig)
|
if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
|
||||||
if any(spec.dtype == tf.int32 for spec in sig.values()):
|
|
||||||
int64_spec = {
|
int64_spec = {
|
||||||
key: tf.TensorSpec(
|
key: tf.TensorSpec(
|
||||||
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
|
shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
|
||||||
)
|
)
|
||||||
for key, spec in sig.items()
|
for key, spec in self.input_signature.items()
|
||||||
}
|
}
|
||||||
int64_serving = self.serving.get_concrete_function(int64_spec)
|
int64_serving = self.serving.get_concrete_function(int64_spec)
|
||||||
signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
|
signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
|
||||||
|
|
|
@ -1168,7 +1168,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
||||||
"to train/fine-tine this model, you need a GPU or a TPU"
|
"to train/fine-tune this model, you need a GPU or a TPU"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -98,7 +98,7 @@ class TFMobileViTConvLayer(tf.keras.layers.Layer):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
||||||
"to train/fine-tine this model, you need a GPU or a TPU"
|
"to train/fine-tune this model, you need a GPU or a TPU"
|
||||||
)
|
)
|
||||||
|
|
||||||
padding = int((kernel_size - 1) / 2) * dilation
|
padding = int((kernel_size - 1) / 2) * dilation
|
||||||
|
|
|
@ -1202,7 +1202,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
|
||||||
"to train/fine-tine this model, you need a GPU or a TPU"
|
"to train/fine-tune this model, you need a GPU or a TPU"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
|
def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None):
|
||||||
|
|
|
@ -1065,6 +1065,16 @@ class TFModelTesterMixin:
|
||||||
output_for_kw_input = model(**inputs_np)
|
output_for_kw_input = model(**inputs_np)
|
||||||
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
|
||||||
|
|
||||||
|
def test_valid_input_signature_and_dummies(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
call_args = inspect.signature(model.call).parameters
|
||||||
|
for key in model.input_signature:
|
||||||
|
self.assertIn(key, call_args)
|
||||||
|
for key in model.dummy_inputs:
|
||||||
|
self.assertIn(key, call_args)
|
||||||
|
|
||||||
def test_resize_token_embeddings(self):
|
def test_resize_token_embeddings(self):
|
||||||
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
|
# TODO (joao): after the embeddings refactor is complete, rework this test so as to rely exclusively on
|
||||||
# tf.keras.layers.Embedding
|
# tf.keras.layers.Embedding
|
||||||
|
|
|
@ -217,18 +217,17 @@ class TFCoreModelTesterMixin:
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
class_sig = model._prune_signature(model.input_signature)
|
|
||||||
num_out = len(model(class_inputs_dict))
|
num_out = len(model(class_inputs_dict))
|
||||||
|
|
||||||
for key in list(class_inputs_dict.keys()):
|
for key in list(class_inputs_dict.keys()):
|
||||||
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
# Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
|
||||||
if key not in class_sig:
|
if key not in model.input_signature:
|
||||||
del class_inputs_dict[key]
|
del class_inputs_dict[key]
|
||||||
# Check it's a tensor, in case the inputs dict has some bools in it too
|
# Check it's a tensor, in case the inputs dict has some bools in it too
|
||||||
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
|
||||||
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
|
||||||
|
|
||||||
if set(class_inputs_dict.keys()) != set(class_sig.keys()):
|
if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
|
||||||
continue # Some models have inputs that the preparation functions don't create, we skip those
|
continue # Some models have inputs that the preparation functions don't create, we skip those
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
|
Loading…
Reference in New Issue