diff --git a/README.md b/README.md index f3f2a65017..dd44df703e 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. diff --git a/README_ko.md b/README_ko.md index 73fa63db8b..8c9bb7412e 100644 --- a/README_ko.md +++ b/README_ko.md @@ -279,6 +279,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. diff --git a/README_zh-hans.md b/README_zh-hans.md index eb3fc362e6..7b146438b3 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -303,6 +303,7 @@ conda install -c huggingface transformers 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。 +1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 4d1f95c2fe..885efb9af5 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -315,6 +315,7 @@ conda install -c huggingface transformers 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 11caf4510d..e9b846d5b6 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -294,6 +294,8 @@ title: Reformer - local: model_doc/rembert title: RemBERT + - local: model_doc/resnet + title: ResNet - local: model_doc/retribert title: RetriBERT - local: model_doc/roberta diff --git a/docs/source/index.mdx b/docs/source/index.mdx index f42ed3277c..7995d085c0 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -124,6 +124,7 @@ conversion utilities for the following models. 1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. 1. **[Reformer](model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. 1. **[RemBERT](model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. +1. **[ResNet](model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. 1. **[RoBERTa](model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. 1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. 1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. @@ -230,6 +231,7 @@ Flax), PyTorch, and/or TensorFlow. | Realm | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | +| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/docs/source/model_doc/resnet.mdx b/docs/source/model_doc/resnet.mdx new file mode 100644 index 0000000000..88131c24ba --- /dev/null +++ b/docs/source/model_doc/resnet.mdx @@ -0,0 +1,50 @@ + + +# ResNet + +## Overview + +The ResNet model was proposed in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. Our implementation follows the small changes made by [Nvidia](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch), we apply the `stride=2` for downsampling in bottleneck's `3x3` conv and not in the first `1x1`. This is generally known as "ResNet v1.5". + +ResNet introduced residual connections, they allow to train networks with an unseen number of layers (up to 1000). ResNet won the 2015 ILSVRC & COCO competition, one important milestone in deep computer vision. + +The abstract from the paper is the following: + +*Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers---8x deeper than VGG nets but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won the 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers. +The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are foundations of our submissions to ILSVRC & COCO 2015 competitions, where we also won the 1st places on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.* + +Tips: + +- One can use [`AutoFeatureExtractor`] to prepare images for the model. + +The figure below illustrates the architecture of ResNet. Taken from the [original paper](https://arxiv.org/abs/1512.03385). + + + +This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks). + +## ResNetConfig + +[[autodoc]] ResNetConfig + + +## ResNetModel + +[[autodoc]] ResNetModel + - forward + + +## ResNetForImageClassification + +[[autodoc]] ResNetForImageClassification + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 19dddc1b38..9b13bc16b0 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -272,6 +272,7 @@ _import_structure = { "models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"], "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"], + "models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"], "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"], "models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"], "models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"], @@ -1323,6 +1324,14 @@ if is_torch_available(): "load_tf_weights_in_rembert", ] ) + _import_structure["models.resnet"].extend( + [ + "RESNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "ResNetForImageClassification", + "ResNetModel", + "ResNetPreTrainedModel", + ] + ) _import_structure["models.retribert"].extend( ["RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "RetriBertModel", "RetriBertPreTrainedModel"] ) @@ -2574,6 +2583,7 @@ if TYPE_CHECKING: from .models.realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig, RealmTokenizer from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig + from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer @@ -3452,6 +3462,12 @@ if TYPE_CHECKING: RemBertPreTrainedModel, load_tf_weights_in_rembert, ) + from .models.resnet import ( + RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + ResNetForImageClassification, + ResNetModel, + ResNetPreTrainedModel, + ) from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel from .models.roberta import ( ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index b5979e23eb..adf00d813a 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -850,3 +850,31 @@ class SemanticSegmentationModelOutput(ModelOutput): logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ImageClassifierOutput(ModelOutput): + """ + Base class for outputs of image classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at + the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8d8c07902c..545a1f4e62 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -94,6 +94,7 @@ from . import ( realm, reformer, rembert, + resnet, retribert, roberta, roformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1591d7adf6..cf1a6d252d 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -33,6 +33,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("maskformer", "MaskFormerConfig"), ("poolformer", "PoolFormerConfig"), ("convnext", "ConvNextConfig"), + ("resnet", "ResNetConfig"), ("yoso", "YosoConfig"), ("swin", "SwinConfig"), ("vilt", "ViltConfig"), @@ -133,6 +134,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -220,6 +222,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("maskformer", "MaskFormer"), ("poolformer", "PoolFormer"), ("convnext", "ConvNext"), + ("resnet", "ResNet"), ("yoso", "YOSO"), ("swin", "Swin"), ("vilt", "ViLT"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 7f1a448200..ff0b1ac36e 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -53,6 +53,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ("vit_mae", "ViTFeatureExtractor"), ("segformer", "SegformerFeatureExtractor"), ("convnext", "ConvNextFeatureExtractor"), + ("resnet", "ConvNextFeatureExtractor"), ("poolformer", "PoolFormerFeatureExtractor"), ("maskformer", "MaskFormerFeatureExtractor"), ] diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 92f1a1f8a3..72895b7e8d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -31,6 +31,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("maskformer", "MaskFormerModel"), ("poolformer", "PoolFormerModel"), ("convnext", "ConvNextModel"), + ("resnet", "ResNetModel"), ("yoso", "YosoModel"), ("swin", "SwinModel"), ("vilt", "ViltModel"), @@ -294,6 +295,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ), ("swin", "SwinForImageClassification"), ("convnext", "ConvNextForImageClassification"), + ("resnet", "ResNetForImageClassification"), ("poolformer", "PoolFormerForImageClassification"), ] ) diff --git a/src/transformers/models/resnet/__init__.py b/src/transformers/models/resnet/__init__.py new file mode 100644 index 0000000000..c561729d54 --- /dev/null +++ b/src/transformers/models/resnet/__init__.py @@ -0,0 +1,52 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...file_utils import _LazyModule, is_torch_available + + +_import_structure = { + "configuration_resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"], +} + +if is_torch_available(): + _import_structure["modeling_resnet"] = [ + "RESNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "ResNetForImageClassification", + "ResNetModel", + "ResNetPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig + + if is_torch_available(): + from .modeling_resnet import ( + RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, + ResNetForImageClassification, + ResNetModel, + ResNetPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/resnet/configuration_resnet.py b/src/transformers/models/resnet/configuration_resnet.py new file mode 100644 index 0000000000..a95562db5e --- /dev/null +++ b/src/transformers/models/resnet/configuration_resnet.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" ResNet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "resnet-50": "https://huggingface.co/microsoft/resnet-50/blob/main/config.json", +} + + +class ResNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an + ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [resnet-50](https://huggingface.co/microsoft/resnet-50) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embedding_size (`int`, *optional*, defaults to 64): + Dimensionality (hidden size) for the embedding layer. + hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`): + Depth (number of layers) for each stage. + layer_type (`str`, *optional*, defaults to `"bottleneck"`): + The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or + `"bottleneck"` (used for larger models like resnet-50 and above). + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` + are supported. + downsample_in_first_stage (`bool`, *optional*, defaults to `False`): + If `True`, the first stage will downsample the inputs using a `stride` of 2. + + Example: + ```python + >>> from transformers import ResNetConfig, ResNetModel + + >>> # Initializing a ResNet resnet-50 style configuration + >>> configuration = ResNetConfig() + >>> # Initializing a model from the resnet-50 style configuration + >>> model = ResNetModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = "resnet" + layer_types = ["basic", "bottleneck"] + + def __init__( + self, + num_channels=3, + embedding_size=64, + hidden_sizes=[256, 512, 1024, 2048], + depths=[3, 4, 6, 3], + layer_type="bottleneck", + hidden_act="relu", + downsample_in_first_stage=False, + **kwargs + ): + super().__init__(**kwargs) + if layer_type not in self.layer_types: + raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}") + self.num_channels = num_channels + self.embedding_size = embedding_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.layer_type = layer_type + self.hidden_act = hidden_act + self.downsample_in_first_stage = downsample_in_first_stage diff --git a/src/transformers/models/resnet/convert_resnet_to_pytorch.py b/src/transformers/models/resnet/convert_resnet_to_pytorch.py new file mode 100644 index 0000000000..47af08861f --- /dev/null +++ b/src/transformers/models/resnet/convert_resnet_to_pytorch.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ResNet checkpoints from timm.""" + + +import argparse +import json +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import List + +import torch +import torch.nn as nn +from torch import Tensor + +import timm +from huggingface_hub import cached_download, hf_hub_url +from transformers import AutoFeatureExtractor, ResNetConfig, ResNetForImageClassification +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger() + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + self.traced.append(m) + + def __call__(self, x: Tensor): + for m in self.module.modules(): + self.handles.append(m.register_forward_hook(self._forward_hook)) + self.module(x) + list(map(lambda x: x.remove(), self.handles)) + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced)) + + +@dataclass +class ModuleTransfer: + src: nn.Module + dest: nn.Module + verbose: int = 0 + src_skip: List = field(default_factory=list) + dest_skip: List = field(default_factory=list) + + def __call__(self, x: Tensor): + """ + Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the + hood we tracked all the operations in both modules. + """ + dest_traced = Tracker(self.dest)(x).parametrized + src_traced = Tracker(self.src)(x).parametrized + + src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced)) + dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced)) + + if len(dest_traced) != len(src_traced): + raise Exception( + f"Numbers of operations are different. Source module has {len(src_traced)} operations while destination module has {len(dest_traced)}." + ) + + for dest_m, src_m in zip(dest_traced, src_traced): + dest_m.load_state_dict(src_m.state_dict()) + if self.verbose == 1: + print(f"Transfered from={src_m} to={dest_m}") + + +def convert_weight_and_push(name: str, config: ResNetConfig, save_directory: Path, push_to_hub: bool = True): + print(f"Converting {name}...") + with torch.no_grad(): + from_model = timm.create_model(name, pretrained=True).eval() + our_model = ResNetForImageClassification(config).eval() + module_transfer = ModuleTransfer(src=from_model, dest=our_model) + x = torch.randn((1, 3, 224, 224)) + module_transfer(x) + + assert torch.allclose(from_model(x), our_model(x).logits), "The model logits don't match the original one." + + checkpoint_name = f"resnet{'-'.join(name.split('resnet'))}" + print(checkpoint_name) + + if push_to_hub: + our_model.push_to_hub( + repo_path_or_name=save_directory / checkpoint_name, + commit_message="Add model", + use_temp_dir=True, + ) + + # we can use the convnext one + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/convnext-base-224-22k-1k") + feature_extractor.push_to_hub( + repo_path_or_name=save_directory / checkpoint_name, + commit_message="Add feature extractor", + use_temp_dir=True, + ) + + print(f"Pushed {checkpoint_name}") + + +def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + expected_shape = (1, num_labels) + + repo_id = "datasets/huggingface/label-files" + num_labels = num_labels + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(ResNetConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "resnet18": ImageNetPreTrainedConfig( + depths=[2, 2, 2, 2], hidden_sizes=[64, 128, 256, 512], layer_type="basic" + ), + "resnet26": ImageNetPreTrainedConfig( + depths=[2, 2, 2, 2], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck" + ), + "resnet34": ImageNetPreTrainedConfig( + depths=[3, 4, 6, 3], hidden_sizes=[64, 128, 256, 512], layer_type="basic" + ), + "resnet50": ImageNetPreTrainedConfig( + depths=[3, 4, 6, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck" + ), + "resnet101": ImageNetPreTrainedConfig( + depths=[3, 4, 23, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck" + ), + "resnet152": ImageNetPreTrainedConfig( + depths=[3, 8, 36, 3], hidden_sizes=[256, 512, 1024, 2048], layer_type="bottleneck" + ), + } + + if model_name: + convert_weight_and_push(model_name, names_to_config[model_name], save_directory, push_to_hub) + else: + for model_name, config in names_to_config.items(): + convert_weight_and_push(model_name, config, save_directory, push_to_hub) + return config, expected_shape + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default=None, + type=str, + help="The name of the model you wish to convert, it must be one of the supported resnet* architecture, currently: resnet18,26,34,50,101,152. If `None`, all of them will the converted.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and feature extractor to the hub.", + ) + + args = parser.parse_args() + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py new file mode 100644 index 0000000000..e26a878459 --- /dev/null +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -0,0 +1,433 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ResNet model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ImageClassifierOutput, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_resnet import ResNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ResNetConfig" +_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "" +_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "" +_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'" + +RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/resnet-50", + # See all resnet models at https://huggingface.co/models?filter=resnet +] + + +@dataclass +class ResNetEncoderOutput(ModelOutput): + """ + ResNet encoder's output, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ResNetModelOutput(ModelOutput): + """ + ResNet model's output, with potential hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, config.hidden_sizes[-1])`): + The pooled last hidden state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, num_channels, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class ResNetConvLayer(nn.Sequential): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu" + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + self.act = ACT2FN[activation] if activation is not None else nn.Identity() + + +class ResNetEmbeddings(nn.Sequential): + """ + ResNet Embedddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, num_channels: int, out_channels: int, activation: str = "relu"): + super().__init__() + self.embedder = ResNetConvLayer(num_channels, out_channels, kernel_size=7, stride=2, activation=activation) + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + +class ResNetShortCut(nn.Sequential): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + + +class ResNetBasicLayer(nn.Module): + """ + A classic ResNet's residual layer composed by a two `3x3` convolutions. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + self.shortcut = ( + ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + ResNetConvLayer(in_channels, out_channels, stride=stride), + ResNetConvLayer(out_channels, out_channels, activation=None), + ) + self.activation = ACT2FN[activation] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetBottleNeckLayer(nn.Module): + """ + A classic ResNet's bottleneck layer composed by a three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remap the reduced features to `out_channels`. + """ + + def __init__( + self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", reduction: int = 4 + ): + super().__init__() + should_apply_shortcut = in_channels != out_channels or stride != 1 + reduces_channels = out_channels // reduction + self.shortcut = ( + ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity() + ) + self.layer = nn.Sequential( + ResNetConvLayer(in_channels, reduces_channels, kernel_size=1), + ResNetConvLayer(reduces_channels, reduces_channels, stride=stride), + ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None), + ) + self.activation = ACT2FN[activation] + + def forward(self, hidden_state): + residual = hidden_state + hidden_state = self.layer(hidden_state) + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetStage(nn.Sequential): + """ + A ResNet stage composed by stacked layers. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 2, + depth: int = 2, + layer_type: str = "basic", + activation: str = "relu", + ): + super().__init__() + + layer = ResNetBottleNeckLayer if layer_type == "bottleneck" else ResNetBasicLayer + + self.layers = nn.Sequential( + # downsampling is done in the first layer with stride of 2 + layer(in_channels, out_channels, stride=stride, activation=activation), + *[layer(out_channels, out_channels, activation=activation) for _ in range(depth - 1)], + ) + + +class ResNetEncoder(nn.Module): + def __init__(self, config: ResNetConfig): + super().__init__() + self.stages = nn.ModuleList([]) + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + self.stages.append( + ResNetStage( + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + layer_type=config.layer_type, + activation=config.hidden_act, + ) + ) + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + self.stages.append( + ResNetStage( + in_channels, out_channels, depth=depth, layer_type=config.layer_type, activation=config.hidden_act + ) + ) + + def forward( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> ResNetEncoderOutput: + hidden_states = () if output_hidden_states else None + + for stage_module in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return ResNetEncoderOutput( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class ResNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ResNetConfig + base_model_prefix = "resnet" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1) + nn.init.constant_(module.bias, 0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ResNetModel): + module.gradient_checkpointing = value + + +RESNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ResNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +RESNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See + [`AutoFeatureExtractor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ResNet model outputting raw features without any specific head on top.", + RESNET_START_DOCSTRING, +) +class ResNetModel(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.embedder = ResNetEmbeddings(config.num_channels, config.embedding_size, config.hidden_act) + self.encoder = ResNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=ResNetModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> ResNetModelOutput: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + embedding_output = self.embedder(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict + ) + + last_hidden_state = encoder_outputs[0] + + pooled_output = self.pooler(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return ResNetModelOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + RESNET_START_DOCSTRING, +) +class ResNetForImageClassification(ResNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.resnet = ResNetModel(config) + # classification head + self.classifier = nn.Sequential( + nn.Flatten(), + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(), + ) + # initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Tensor = None, + labels: Tensor = None, + output_hidden_states: bool = None, + return_dict: bool = None, + ) -> ImageClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f57cb82a78..66bab03dc8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3334,6 +3334,30 @@ def load_tf_weights_in_rembert(*args, **kwargs): requires_backends(load_tf_weights_in_rembert, ["torch"]) +RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class ResNetForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ResNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ResNetPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/resnet/__init__.py b/tests/resnet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/resnet/test_modeling_resnet.py b/tests/resnet/test_modeling_resnet.py new file mode 100644 index 0000000000..85367155f2 --- /dev/null +++ b/tests/resnet/test_modeling_resnet.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch ResNet model. """ + + +import inspect +import unittest + +from transformers import ResNetConfig +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from ..test_configuration_common import ConfigTester +from ..test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import ResNetForImageClassification, ResNetModel + from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoFeatureExtractor + + +class ResNetModelTester: + def __init__( + self, + parent, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[10, 20, 30, 40], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return ResNetConfig( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = ResNetModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected last hidden states: B, C, H // 32, W // 32 + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32), + ) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.num_labels + model = ResNetForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class ResNetModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (ResNetModel, ResNetForImageClassification) if is_torch_available() else () + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = False + + def setUp(self): + self.model_tester = ResNetModelTester(self) + self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="ResNet does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="ResNet does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, module in model.named_modules(): + if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)): + self.assertTrue( + torch.all(module.weight == 1), + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + self.assertTrue( + torch.all(module.bias == 0), + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_stages = self.model_tester.num_stages + self.assertEqual(len(hidden_states), expected_num_stages + 1) + + # ResNet's feature maps are of shape (batch_size, num_channels, height, width) + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [self.model_tester.image_size // 4, self.model_tester.image_size // 4], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + layers_type = ["basic", "bottleneck"] + for model_class in self.all_model_classes: + for layer_type in layers_type: + config.layer_type = layer_type + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = ResNetModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_torch +@require_vision +class ResNetModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + AutoFeatureExtractor.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = ResNetForImageClassification.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 1bbba630c2..3fe4b05853 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -31,3 +31,4 @@ src/transformers/models/plbart/modeling_plbart.py src/transformers/generation_utils.py docs/source/quicktour.mdx docs/source/task_summary.mdx +src/transformers/models/resnet/modeling_resnet.py