From 5323094a222e41529e9dcc2e5534e5053ece81e1 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 9 Jun 2022 14:44:27 +0200 Subject: [PATCH] Add ONNX support for ResNet (#17585) * Add ONNX support for ResNet * Add ONNX test * make fix-copies --- docs/source/en/serialization.mdx | 1 + src/transformers/models/resnet/__init__.py | 6 +++-- .../models/resnet/configuration_resnet.py | 23 +++++++++++++++++++ src/transformers/onnx/features.py | 5 ++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 34 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index e6ba52e39f..bf172bd199 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -72,6 +72,7 @@ Ready-made configurations include the following architectures: - OpenAI GPT-2 - Perceiver - PLBart +- ResNet - RoBERTa - RoFormer - SqueezeBERT diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py index b5ab2d38fb..e1c0a9ec84 100644 --- a/src/transformers/models/resnet/__init__.py +++ b/src/transformers/models/resnet/__init__.py @@ -21,7 +21,9 @@ from typing import TYPE_CHECKING from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available -_import_structure = {"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"]} +_import_structure = { + "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"] +} try: if not is_torch_available(): @@ -38,7 +40,7 @@ else: if TYPE_CHECKING: - from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig + from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig try: if not is_torch_available(): diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py index 8e5f6e656d..9bfc694bb1 100644 --- a/src/transformers/models/resnet/configuration_resnet.py +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -14,7 +14,13 @@ # limitations under the License. """ ResNet model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -89,3 +95,20 @@ class ResNetConfig(PretrainedConfig): self.layer_type = layer_type self.hidden_act = hidden_act self.downsample_in_first_stage = downsample_in_first_stage + + +class ResNetOnnxConfig(OnnxConfig): + + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-3 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index b97ee4aa7d..d29831c36b 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -318,6 +318,11 @@ class FeaturesManager: "sequence-classification", onnx_config_cls="models.perceiver.PerceiverOnnxConfig", ), + "resnet": supported_features_mapping( + "default", + "image-classification", + onnx_config_cls="models.resnet.ResNetOnnxConfig", + ), "roberta": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 2b5a02bd59..2f73169def 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = { ("convbert", "YituTech/conv-bert-base"), ("distilbert", "distilbert-base-cased"), ("electra", "google/electra-base-generator"), + ("resnet", "microsoft/resnet-50"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), ("squeezebert", "squeezebert/squeezebert-uncased"),