Add ONNX support for ResNet (#17585)

* Add ONNX support for ResNet

* Add ONNX test

* make fix-copies
This commit is contained in:
regisss 2022-06-09 14:44:27 +02:00 committed by GitHub
parent ca2a55e9df
commit 5323094a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 2 deletions

View File

@ -72,6 +72,7 @@ Ready-made configurations include the following architectures:
- OpenAI GPT-2
- Perceiver
- PLBart
- ResNet
- RoBERTa
- RoFormer
- SqueezeBERT

View File

@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"]}
_import_structure = {
"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"]
}
try:
if not is_torch_available():
@ -38,7 +40,7 @@ else:
if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
try:
if not is_torch_available():

View File

@ -14,7 +14,13 @@
# limitations under the License.
""" ResNet model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
@ -89,3 +95,20 @@ class ResNetConfig(PretrainedConfig):
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage
class ResNetOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-3

View File

@ -318,6 +318,11 @@ class FeaturesManager:
"sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
),
"resnet": supported_features_mapping(
"default",
"image-classification",
onnx_config_cls="models.resnet.ResNetOnnxConfig",
),
"roberta": supported_features_mapping(
"default",
"masked-lm",

View File

@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("convbert", "YituTech/conv-bert-base"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
("roformer", "junnyu/roformer_chinese_base"),
("squeezebert", "squeezebert/squeezebert-uncased"),