Enabling custom TF signature draft (#19249)

* Custom TF signature draft

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Adding tf signature tests

* Fixing signature check and adding asserts

* fixing model load path

* Adjusting signature tests

* Formatting file

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Dimitre Oliveira <dimitreoliveira@Dimitres-MacBook-Air.local>
This commit is contained in:
Dimitre Oliveira 2022-10-11 06:56:08 -03:00 committed by GitHub
parent 10100979ed
commit df8faba4db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 1 deletions

View File

@ -2097,6 +2097,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
saved_model=False,
version=1,
push_to_hub=False,
signatures=None,
max_shard_size: Union[int, str] = "10GB",
create_pr: bool = False,
**kwargs
@ -2118,6 +2119,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
signatures (`dict` or `tf.function`, *optional*):
Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
@ -2148,8 +2151,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
files_timestamps = self._get_files_timestamps(save_directory)
if saved_model:
if signatures is None:
signatures = self.serving
saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
self.save(saved_model_dir, include_optimizer=False, signatures=self.serving)
self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
logger.info(f"Saved model created in {saved_model_dir}")
# Save configuration file

View File

@ -2216,6 +2216,46 @@ class UtilsFunctionsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Short custom TF signature function.
# `input_signature` is specific to BERT.
@tf.function(
input_signature=[
[
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
]
]
)
def serving_fn(input):
return model(input)
# Using default signature (default behavior) overrides 'serving_default'
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
# Providing custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
# Providing multiple custom signature function
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
saved_model=True,
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
)
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
@require_tf
@is_staging_test