[WIP] Resnet (#15770)

* first commit

* ResNet model correctly implemented.

basic modeling + weights conversion is done

removed unused doc

mdx file

doc and conversion script

added feature_extractor to auto

test

minor changes + style + quality

doc

test

Delete process.yml

A left over from my attempt of running circleci locally

* minor changes

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* new test format

* minor changes from conversations

* minor changes from conversations

* make style + quality

* readded the tests

* test + README

* minor changes from conversations

* error in README

* make fix-copies

* removed regression for classification head

* make quality

* fixed loss control flow

* fixed loss control flow

* resolved conversations

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* READMEs

* index.mdx

* minor changes

* updated tests and models

* unused import

* outputs

* Update docs/source/model_doc/resnet.mdx

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* added embeddings_size

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* conversation

* added push to hub

* test

* embedding_size

* make fix-copies

* resolved conversations

* CI

* changed organization

* minor changes

* CI

* minor changes

* conversations

* conversation

* doc

* tests

* removed unused docstring

* conversation

* removed unused outputs

* CI

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
Francesco Saverio Zuppichini 2022-03-14 19:57:55 +01:00 committed by GitHub
parent 6458236181
commit e3008c679f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1177 additions and 0 deletions

View File

@ -300,6 +300,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.

View File

@ -279,6 +279,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.

View File

@ -303,6 +303,7 @@ conda install -c huggingface transformers
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。

View File

@ -315,6 +315,7 @@ conda install -c huggingface transformers
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.

View File

@ -294,6 +294,8 @@
title: Reformer
- local: model_doc/rembert
title: RemBERT
- local: model_doc/resnet
title: ResNet
- local: model_doc/retribert
title: RetriBERT
- local: model_doc/roberta

View File

@ -124,6 +124,7 @@ conversion utilities for the following models.
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
@ -230,6 +231,7 @@ Flax), PyTorch, and/or TensorFlow.
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |

View File

@ -0,0 +1,50 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ResNet
## Overview
The ResNet model was proposed in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. Our implementation follows the small changes made by [Nvidia](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch), we apply the `stride=2` for downsampling in bottleneck's `3x3` conv and not in the first `1x1`. This is generally known as "ResNet v1.5".
ResNet introduced residual connections, they allow to train networks with an unseen number of layers (up to 1000). ResNet won the 2015 ILSVRC & COCO competition, one important milestone in deep computer vision.
The abstract from the paper is the following:
*Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers---8x deeper than VGG nets but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won the 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers.
The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are foundations of our submissions to ILSVRC & COCO 2015 competitions, where we also won the 1st places on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.*
Tips:
- One can use [`AutoFeatureExtractor`] to prepare images for the model.
The figure below illustrates the architecture of ResNet. Taken from the [original paper](https://arxiv.org/abs/1512.03385).
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/resnet_architecture.png"/>
This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).
## ResNetConfig
[[autodoc]] ResNetConfig
## ResNetModel
[[autodoc]] ResNetModel
- forward
## ResNetForImageClassification
[[autodoc]] ResNetForImageClassification
- forward

View File

@ -272,6 +272,7 @@ _import_structure = {
"models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"],
"models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
"models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
"models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
"models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
"models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"],
@ -1323,6 +1324,14 @@ if is_torch_available():
"load_tf_weights_in_rembert",
]
)
_import_structure["models.resnet"].extend(
[
"RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ResNetForImageClassification",
"ResNetModel",
"ResNetPreTrainedModel",
]
)
_import_structure["models.retribert"].extend(
["RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "RetriBertModel", "RetriBertPreTrainedModel"]
)
@ -2574,6 +2583,7 @@ if TYPE_CHECKING:
from .models.realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig, RealmTokenizer
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
@ -3452,6 +3462,12 @@ if TYPE_CHECKING:
RemBertPreTrainedModel,
load_tf_weights_in_rembert,
)
from .models.resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetForImageClassification,
ResNetModel,
ResNetPreTrainedModel,
)
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -850,3 +850,31 @@ class SemanticSegmentationModelOutput(ModelOutput):
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ImageClassifierOutput(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None

View File

@ -94,6 +94,7 @@ from . import (
realm,
reformer,
rembert,
resnet,
retribert,
roberta,
roformer,

View File

@ -33,6 +33,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("maskformer", "MaskFormerConfig"),
("poolformer", "PoolFormerConfig"),
("convnext", "ConvNextConfig"),
("resnet", "ResNetConfig"),
("yoso", "YosoConfig"),
("swin", "SwinConfig"),
("vilt", "ViltConfig"),
@ -133,6 +134,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -220,6 +222,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("maskformer", "MaskFormer"),
("poolformer", "PoolFormer"),
("convnext", "ConvNext"),
("resnet", "ResNet"),
("yoso", "YOSO"),
("swin", "Swin"),
("vilt", "ViLT"),

View File

@ -53,6 +53,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("vit_mae", "ViTFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
]

View File

@ -31,6 +31,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("maskformer", "MaskFormerModel"),
("poolformer", "PoolFormerModel"),
("convnext", "ConvNextModel"),
("resnet", "ResNetModel"),
("yoso", "YosoModel"),
("swin", "SwinModel"),
("vilt", "ViltModel"),
@ -294,6 +295,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
),
("swin", "SwinForImageClassification"),
("convnext", "ConvNextForImageClassification"),
("resnet", "ResNetForImageClassification"),
("poolformer", "PoolFormerForImageClassification"),
]
)

View File

@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available
_import_structure = {
"configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
}
if is_torch_available():
_import_structure["modeling_resnet"] = [
"RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ResNetForImageClassification",
"ResNetModel",
"ResNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
if is_torch_available():
from .modeling_resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetForImageClassification,
ResNetModel,
ResNetPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ResNet model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"resnet-50": "https://huggingface.co/microsoft/resnet-50/blob/main/config.json",
}
class ResNetConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an
ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
embedding_size (`int`, *optional*, defaults to 64):
Dimensionality (hidden size) for the embedding layer.
hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
Dimensionality (hidden size) at each stage.
depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
Depth (number of layers) for each stage.
layer_type (`str`, *optional*, defaults to `"bottleneck"`):
The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
`"bottleneck"` (used for larger models like resnet-50 and above).
hidden_act (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
are supported.
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
Example:
```python
>>> from transformers import ResNetConfig, ResNetModel
>>> # Initializing a ResNet resnet-50 style configuration
>>> configuration = ResNetConfig()
>>> # Initializing a model from the resnet-50 style configuration
>>> model = ResNetModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "resnet"
layer_types = ["basic", "bottleneck"]
def __init__(
self,
num_channels=3,
embedding_size=64,
hidden_sizes=[256, 512, 1024, 2048],
depths=[3, 4, 6, 3],
layer_type="bottleneck",
hidden_act="relu",
downsample_in_first_stage=False,
**kwargs
):
super().__init__(**kwargs)
if layer_type not in self.layer_types:
raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
self.num_channels = num_channels
self.embedding_size = embedding_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.layer_type = layer_type
self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage

View File

@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ResNet checkpoints from timm."""
import argparse
import json
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import List
import torch
import torch.nn as nn
from torch import Tensor
import timm
from huggingface_hub import cached_download, hf_hub_url
from transformers import AutoFeatureExtractor, ResNetConfig, ResNetForImageClassification
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger()
@dataclass
class Tracker:
module: nn.Module
traced: List[nn.Module] = field(default_factory=list)
handles: list = field(default_factory=list)
def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
if has_not_submodules:
self.traced.append(m)
def __call__(self, x: Tensor):
for m in self.module.modules():
self.handles.append(m.register_forward_hook(self._forward_hook))
self.module(x)
list(map(lambda x: x.remove(), self.handles))
return self
@property
def parametrized(self):
# check the len of the state_dict keys to see if we have learnable params
return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced))
@dataclass
class ModuleTransfer:
src: nn.Module
dest: nn.Module
verbose: int = 0
src_skip: List = field(default_factory=list)
dest_skip: List = field(default_factory=list)
def __call__(self, x: Tensor):
"""
Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
hood we tracked all the operations in both modules.
"""
dest_traced = Tracker(self.dest)(x).parametrized
src_traced = Tracker(self.src)(x).parametrized
src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced))
dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced))
if len(dest_traced) != len(src_traced):
raise Exception(
f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}."
)
for dest_m, src_m in zip(dest_traced, src_traced):
dest_m.load_state_dict(src_m.state_dict())
if self.verbose == 1:
print(f"Transfered from={src_m} to={dest_m}")
def convert_weight_and_push(name: str, config: ResNetConfig, save_directory: Path, push_to_hub: bool = True):
print(f"Converting {name}...")
with torch.no_grad():
from_model = timm.create_model(name, pretrained=True).eval()
our_model = ResNetForImageClassification(config).eval()
module_transfer = ModuleTransfer(src=from_model, dest=our_model)
x = torch.randn((1, 3, 224, 224))
module_transfer(x)
assert torch.allclose(from_model(x), our_model(x).logits), "The model logits don't match the original one."
checkpoint_name = f"resnet{'-'.join(name.split('resnet'))}"
print(checkpoint_name)
if push_to_hub:
our_model.push_to_hub(
repo_path_or_name=save_directory / checkpoint_name,
commit_message="Add model",
use_temp_dir=True,
)
# we can use the convnext one
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k")
feature_extractor.push_to_hub(
repo_path_or_name=save_directory / checkpoint_name,
commit_message="Add feature extractor",
use_temp_dir=True,
)
print(f"Pushed {checkpoint_name}")
def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
filename = "imagenet-1k-id2label.json"
num_labels = 1000
expected_shape = (1, num_labels)
repo_id = "datasets/huggingface/label-files"
num_labels = num_labels
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
id2label = {int(k): v for k, v in id2label.items()}
id2label = id2label
label2id = {v: k for k, v in id2label.items()}
ImageNetPreTrainedConfig = partial(ResNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
names_to_config = {
"resnet18": ImageNetPreTrainedConfig(
depths=[2, 2, 2, 2], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
),
"resnet26": ImageNetPreTrainedConfig(
depths=[2, 2, 2, 2], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
),
"resnet34": ImageNetPreTrainedConfig(
depths=[3, 4, 6, 3], hidden_sizes=[64, 128, 256, 512], layer_type="basic"
),
"resnet50": ImageNetPreTrainedConfig(
depths=[3, 4, 6, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
),
"resnet101": ImageNetPreTrainedConfig(
depths=[3, 4, 23, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
),
"resnet152": ImageNetPreTrainedConfig(
depths=[3, 8, 36, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck"
),
}
if model_name:
convert_weight_and_push(model_name, names_to_config[model_name], save_directory, push_to_hub)
else:
for model_name, config in names_to_config.items():
convert_weight_and_push(model_name, config, save_directory, push_to_hub)
return config, expected_shape
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default=None,
type=str,
help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=Path,
required=True,
help="Path to the output PyTorch model directory.",
)
parser.add_argument(
"--push_to_hub",
default=True,
type=bool,
required=False,
help="If True, push model and feature extractor to the hub.",
)
args = parser.parse_args()
pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)

View File

@ -0,0 +1,433 @@
# coding=utf-8
# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch ResNet model."""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import ImageClassifierOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_resnet import ResNetConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "ResNetConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = ""
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = ""
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/resnet-50",
# See all resnet models at https://huggingface.co/models?filter=resnet
]
@dataclass
class ResNetEncoderOutput(ModelOutput):
"""
ResNet encoder's output, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ResNetModelOutput(ModelOutput):
"""
ResNet model's output, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.hidden_sizes[-1])`):
The pooled last hidden state.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class ResNetConvLayer(nn.Sequential):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
):
super().__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = ACT2FN[activation] if activation is not None else nn.Identity()
class ResNetEmbeddings(nn.Sequential):
"""
ResNet Embedddings (stem) composed of a single aggressive convolution.
"""
def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"):
super().__init__()
self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation)
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
class ResNetShortCut(nn.Sequential):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
class ResNetBasicLayer(nn.Module):
"""
A classic ResNet's residual layer composed by a two `3x3` convolutions.
"""
def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
super().__init__()
should_apply_shortcut = in_channels != out_channels or stride != 1
self.shortcut = (
ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
)
self.layer = nn.Sequential(
ResNetConvLayer(in_channels, out_channels, stride=stride),
ResNetConvLayer(out_channels, out_channels, activation=None),
)
self.activation = ACT2FN[activation]
def forward(self, hidden_state):
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class ResNetBottleNeckLayer(nn.Module):
"""
A classic ResNet's bottleneck layer composed by a three `3x3` convolutions.
The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`.
"""
def __init__(
self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", reduction: int = 4
):
super().__init__()
should_apply_shortcut = in_channels != out_channels or stride != 1
reduces_channels = out_channels // reduction
self.shortcut = (
ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
)
self.layer = nn.Sequential(
ResNetConvLayer(in_channels, reduces_channels, kernel_size=1),
ResNetConvLayer(reduces_channels, reduces_channels, stride=stride),
ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
)
self.activation = ACT2FN[activation]
def forward(self, hidden_state):
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class ResNetStage(nn.Sequential):
"""
A ResNet stage composed by stacked layers.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 2,
depth: int = 2,
layer_type: str = "basic",
activation: str = "relu",
):
super().__init__()
layer = ResNetBottleNeckLayer if layer_type == "bottleneck" else ResNetBasicLayer
self.layers = nn.Sequential(
# downsampling is done in the first layer with stride of 2
layer(in_channels, out_channels, stride=stride, activation=activation),
*[layer(out_channels, out_channels, activation=activation) for _ in range(depth - 1)],
)
class ResNetEncoder(nn.Module):
def __init__(self, config: ResNetConfig):
super().__init__()
self.stages = nn.ModuleList([])
# based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
self.stages.append(
ResNetStage(
config.embedding_size,
config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0],
layer_type=config.layer_type,
activation=config.hidden_act,
)
)
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
self.stages.append(
ResNetStage(
in_channels, out_channels, depth=depth, layer_type=config.layer_type, activation=config.hidden_act
)
)
def forward(
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
) -> ResNetEncoderOutput:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
hidden_state = stage_module(hidden_state)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return ResNetEncoderOutput(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
class ResNetPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ResNetModel):
module.gradient_checkpointing = value
RESNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
RESNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare ResNet model outputting raw features without any specific head on top.",
RESNET_START_DOCSTRING,
)
class ResNetModel(ResNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embedder = ResNetEmbeddings(config.num_channels, config.embedding_size, config.hidden_act)
self.encoder = ResNetEncoder(config)
self.pooler = nn.AdaptiveAvgPool2d((1, 1))
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=ResNetModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
) -> ResNetModelOutput:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
embedding_output = self.embedder(pixel_values)
encoder_outputs = self.encoder(
embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
)
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(last_hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return ResNetModelOutput(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
)
@add_start_docstrings(
"""
ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
RESNET_START_DOCSTRING,
)
class ResNetForImageClassification(ResNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.resnet = ResNetModel(config)
# classification head
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
)
# initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values: Tensor = None,
labels: Tensor = None,
output_hidden_states: bool = None,
return_dict: bool = None,
) -> ImageClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return (loss,) + output if loss is not None else output
return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

View File

@ -3334,6 +3334,30 @@ def load_tf_weights_in_rembert(*args, **kwargs):
requires_backends(load_tf_weights_in_rembert, ["torch"])
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ResNetForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ResNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ResNetPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

0
tests/resnet/__init__.py Normal file
View File

View File

@ -0,0 +1,271 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ResNet model. """
import inspect
import unittest
from transformers import ResNetConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ..test_configuration_common import ConfigTester
from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import ResNetForImageClassification, ResNetModel
from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class ResNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ResNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
model = ResNetModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = ResNetForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else ()
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = ResNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ResNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
for name, module in model.named_modules():
if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
self.assertTrue(
torch.all(module.weight == 1),
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
self.assertTrue(
torch.all(module.bias == 0),
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# ResNet's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
layers_type = ["basic", "bottleneck"]
for model_class in self.all_model_classes:
for layer_type in layers_type:
config.layer_type = layer_type
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = ResNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
class ResNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
AutoFeatureExtractor.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = ResNetForImageClassification.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

View File

@ -31,3 +31,4 @@ src/transformers/models/plbart/modeling_plbart.py
src/transformers/generation_utils.py
docs/source/quicktour.mdx
docs/source/task_summary.mdx
src/transformers/models/resnet/modeling_resnet.py