Add ONNX support for ResNet (#17585)
* Add ONNX support for ResNet * Add ONNX test * make fix-copies
This commit is contained in:
parent
ca2a55e9df
commit
5323094a22
|
@ -72,6 +72,7 @@ Ready-made configurations include the following architectures:
|
|||
- OpenAI GPT-2
|
||||
- Perceiver
|
||||
- PLBart
|
||||
- ResNet
|
||||
- RoBERTa
|
||||
- RoFormer
|
||||
- SqueezeBERT
|
||||
|
|
|
@ -21,7 +21,9 @@ from typing import TYPE_CHECKING
|
|||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"]}
|
||||
_import_structure = {
|
||||
"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig", "ResNetOnnxConfig"]
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
@ -38,7 +40,7 @@ else:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
|
||||
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
|
|
|
@ -14,7 +14,13 @@
|
|||
# limitations under the License.
|
||||
""" ResNet model configuration"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Mapping
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
|
@ -89,3 +95,20 @@ class ResNetConfig(PretrainedConfig):
|
|||
self.layer_type = layer_type
|
||||
self.hidden_act = hidden_act
|
||||
self.downsample_in_first_stage = downsample_in_first_stage
|
||||
|
||||
|
||||
class ResNetOnnxConfig(OnnxConfig):
|
||||
|
||||
torch_onnx_minimum_version = version.parse("1.11")
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
return OrderedDict(
|
||||
[
|
||||
("pixel_values", {0: "batch", 1: "sequence"}),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def atol_for_validation(self) -> float:
|
||||
return 1e-3
|
||||
|
|
|
@ -318,6 +318,11 @@ class FeaturesManager:
|
|||
"sequence-classification",
|
||||
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
|
||||
),
|
||||
"resnet": supported_features_mapping(
|
||||
"default",
|
||||
"image-classification",
|
||||
onnx_config_cls="models.resnet.ResNetOnnxConfig",
|
||||
),
|
||||
"roberta": supported_features_mapping(
|
||||
"default",
|
||||
"masked-lm",
|
||||
|
|
|
@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||
("convbert", "YituTech/conv-bert-base"),
|
||||
("distilbert", "distilbert-base-cased"),
|
||||
("electra", "google/electra-base-generator"),
|
||||
("resnet", "microsoft/resnet-50"),
|
||||
("roberta", "roberta-base"),
|
||||
("roformer", "junnyu/roformer_chinese_base"),
|
||||
("squeezebert", "squeezebert/squeezebert-uncased"),
|
||||
|
|
Loading…
Reference in New Issue