diff --git a/README.md b/README.md
index f1e205702a..008089f8fe 100644
--- a/README.md
+++ b/README.md
@@ -311,6 +311,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER
AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMAE)](https://huggingface.co/docs/transformers/master/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
diff --git a/README_ko.md b/README_ko.md
index 21d06e51cc..303503b4c8 100644
--- a/README_ko.md
+++ b/README_ko.md
@@ -289,6 +289,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE)](https://huggingface.co/docs/transformers/master/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
diff --git a/README_zh-hans.md b/README_zh-hans.md
index 649e7bba3e..851ff9f4c5 100644
--- a/README_zh-hans.md
+++ b/README_zh-hans.md
@@ -313,6 +313,7 @@ conda install -c huggingface transformers
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (来自 Microsoft Research) 伴随论文 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 由 Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 发布。
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (来自 Microsoft Research) 伴随论文 [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) 由 Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu 发布。
+1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (来自 NAVER AI Lab/Kakao Enterprise/Kakao Brain) 伴随论文 [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) 由 Wonjae Kim, Bokyung Son, Ildoo Kim 发布。
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (来自 UCLA NLP) 伴随论文 [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) 由 Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang 发布。
1. **[ViTMAE)](https://huggingface.co/docs/transformers/master/model_doc/vit_mae)** (来自 Meta AI) 伴随论文 [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) 由 Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick 发布。
diff --git a/README_zh-hant.md b/README_zh-hant.md
index 241900c532..0a0ab66c91 100644
--- a/README_zh-hant.md
+++ b/README_zh-hant.md
@@ -325,6 +325,7 @@ conda install -c huggingface transformers
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](https://huggingface.co/docs/transformers/model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[ViLT](https://huggingface.co/docs/transformers/master/model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[VisualBERT](https://huggingface.co/docs/transformers/model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
1. **[ViTMAE)](https://huggingface.co/docs/transformers/master/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index cde8845bb9..b0ef21e522 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -282,6 +282,8 @@
title: UniSpeech
- local: model_doc/unispeech-sat
title: UniSpeech-SAT
+ - local: model_doc/vilt
+ title: ViLT
- local: model_doc/vision-encoder-decoder
title: Vision Encoder Decoder Models
- local: model_doc/vision-text-dual-encoder
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index 7306ff1a4d..94fe8fe8a8 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -170,6 +170,7 @@ conversion utilities for the following models.
1. **[TrOCR](model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[UniSpeech](model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang.
1. **[UniSpeechSat](model_doc/unispeech-sat)** (from Microsoft Research) released with the paper [UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen, Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
+1. **[ViLT](model_doc/vilt)** (from NAVER AI Lab/Kakao Enterprise/Kakao Brain) released with the paper [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334) by Wonjae Kim, Bokyung Son, Ildoo Kim.
1. **[Vision Transformer (ViT)](model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMAE)](model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick.
1. **[VisualBERT](model_doc/visual_bert)** (from UCLA NLP) released with the paper [VisualBERT: A Simple and Performant Baseline for Vision and Language](https://arxiv.org/pdf/1908.03557) by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
@@ -266,6 +267,7 @@ Flax), PyTorch, and/or TensorFlow.
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeechSat | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ViLT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Vision Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
diff --git a/docs/source/model_doc/vilt.mdx b/docs/source/model_doc/vilt.mdx
new file mode 100644
index 0000000000..9170d84ead
--- /dev/null
+++ b/docs/source/model_doc/vilt.mdx
@@ -0,0 +1,87 @@
+
+
+# ViLT
+
+## Overview
+
+The ViLT model was proposed in [ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision](https://arxiv.org/abs/2102.03334)
+by Wonjae Kim, Bokyung Son, Ildoo Kim. ViLT incorporates text embeddings into a Vision Transformer (ViT), allowing it to have a minimal design
+for Vision-and-Language Pre-training (VLP).
+
+The abstract from the paper is the following:
+
+*Vision-and-Language Pre-training (VLP) has improved performance on various joint vision-and-language downstream tasks.
+Current approaches to VLP heavily rely on image feature extraction processes, most of which involve region supervision
+(e.g., object detection) and the convolutional architecture (e.g., ResNet). Although disregarded in the literature, we
+find it problematic in terms of both (1) efficiency/speed, that simply extracting input features requires much more
+computation than the multimodal interaction steps; and (2) expressive power, as it is upper bounded to the expressive
+power of the visual embedder and its predefined visual vocabulary. In this paper, we present a minimal VLP model,
+Vision-and-Language Transformer (ViLT), monolithic in the sense that the processing of visual inputs is drastically
+simplified to just the same convolution-free manner that we process textual inputs. We show that ViLT is up to tens of
+times faster than previous VLP models, yet with competitive or better downstream task performance.*
+
+Tips:
+
+- ViLT is a model that takes both `pixel_values` and `input_ids` as input. One can use [`ViltProcessor`] to prepare data for the model (see the usage sketch below this list).
+ This processor wraps a feature extractor (for the image modality) and a tokenizer (for the language modality) into one.
+- ViLT is trained with images of various sizes: the authors resize the shorter edge of input images to 384 and limit the longer edge to
+ under 640 while preserving the aspect ratio. To make batching of images possible, the authors use a `pixel_mask` that indicates
+ which pixel values are real and which are padding. [`ViltProcessor`] automatically creates this for you.
+- The design of ViLT is very similar to that of a standard Vision Transformer (ViT). The only difference is that the model includes
+ additional embedding layers for the language modality.
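+
+The sketch below shows how the processor and a task-specific head could be used together for visual question answering. It is a minimal, illustrative example: the fine-tuned checkpoint name `dandelin/vilt-b32-finetuned-vqa` is an assumption used for demonstration purposes only.
+
+```python
+import requests
+from PIL import Image
+
+from transformers import ViltProcessor, ViltForQuestionAnswering
+
+# assumed fine-tuned VQA checkpoint (illustrative name)
+processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+
+# an image and a question about it
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+text = "How many cats are there?"
+
+# the processor resizes/normalizes the image and tokenizes the question in a single call
+encoding = processor(image, text, return_tensors="pt")
+
+# the logits range over the VQA answer vocabulary; id2label maps the best index to an answer string
+outputs = model(**encoding)
+predicted_idx = outputs.logits.argmax(-1).item()
+print("Predicted answer:", model.config.id2label[predicted_idx])
+```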
+
+ViLT architecture. Taken from the original paper.
+
+This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/dandelin/ViLT).
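+
+For tasks that pair one sentence with *two* images (such as NLVR2), [`ViltForImagesAndTextClassification`] accepts the second image through the `pixel_values_2` argument, following the same call pattern as the conversion script shipped with the model. The sketch below assumes a fine-tuned NLVR2 checkpoint named `dandelin/vilt-b32-finetuned-nlvr2`; that name is illustrative.
+
+```python
+import requests
+from PIL import Image
+
+from transformers import ViltProcessor, ViltForImagesAndTextClassification
+
+# assumed fine-tuned NLVR2 checkpoint (illustrative name)
+processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
+model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
+
+# the same example image is used twice here, purely for illustration
+url = "https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg"
+image1 = Image.open(requests.get(url, stream=True).raw)
+image2 = Image.open(requests.get(url, stream=True).raw)
+text = "The left image contains twice the number of dogs as the right image."
+
+# encode each image together with the sentence, then pass the second image as `pixel_values_2`
+encoding_1 = processor(image1, text, return_tensors="pt")
+encoding_2 = processor(image2, text, return_tensors="pt")
+outputs = model(
+    input_ids=encoding_1.input_ids,
+    pixel_values=encoding_1.pixel_values,
+    pixel_values_2=encoding_2.pixel_values,
+)
+
+# the head is a binary classifier; id2label maps the prediction to "True" or "False"
+predicted_idx = outputs.logits.argmax(-1).item()
+print("Predicted label:", model.config.id2label[predicted_idx])
+```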
+
+## ViltConfig
+
+[[autodoc]] ViltConfig
+
+## ViltFeatureExtractor
+
+[[autodoc]] ViltFeatureExtractor
+ - __call__
+
+## ViltProcessor
+
+[[autodoc]] ViltProcessor
+ - __call__
+
+## ViltModel
+
+[[autodoc]] ViltModel
+ - forward
+
+## ViltForMaskedLM
+
+[[autodoc]] ViltForMaskedLM
+ - forward
+
+## ViltForQuestionAnswering
+
+[[autodoc]] ViltForQuestionAnswering
+ - forward
+
+## ViltForImagesAndTextClassification
+
+[[autodoc]] ViltForImagesAndTextClassification
+ - forward
+
+## ViltForImageAndTextRetrieval
+
+[[autodoc]] ViltForImageAndTextRetrieval
+ - forward
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 137f7ff939..5a2f0ac8f4 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -308,6 +308,7 @@ _import_structure = {
"UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"UniSpeechSatConfig",
],
+ "models.vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig", "ViltFeatureExtractor", "ViltProcessor"],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
@@ -514,6 +515,8 @@ if is_vision_available():
_import_structure["models.layoutxlm"].append("LayoutXLMProcessor")
_import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
_import_structure["models.segformer"].append("SegformerFeatureExtractor")
+ _import_structure["models.vilt"].append("ViltFeatureExtractor")
+ _import_structure["models.vilt"].append("ViltProcessor")
_import_structure["models.vit"].append("ViTFeatureExtractor")
else:
from .utils import dummy_vision_objects
@@ -629,7 +632,6 @@ if is_torch_available():
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure
-
_import_structure["models.albert"].extend(
[
"ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1382,6 +1384,18 @@ if is_torch_available():
"UniSpeechSatPreTrainedModel",
]
)
+ _import_structure["models.vilt"].extend(
+ [
+ "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "ViltForImageAndTextRetrieval",
+ "ViltForImagesAndTextClassification",
+ "ViltForMaskedLM",
+ "ViltForQuestionAnswering",
+ "ViltLayer",
+ "ViltModel",
+ "ViltPreTrainedModel",
+ ]
+ )
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
_import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"])
_import_structure["models.visual_bert"].extend(
@@ -2409,6 +2423,7 @@ if TYPE_CHECKING:
from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor
from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
+ from .models.vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig, ViltFeatureExtractor, ViltProcessor
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
@@ -2585,6 +2600,7 @@ if TYPE_CHECKING:
from .models.layoutxlm import LayoutXLMProcessor
from .models.perceiver import PerceiverFeatureExtractor
from .models.segformer import SegformerFeatureExtractor
+ from .models.vilt import ViltFeatureExtractor, ViltProcessor
from .models.vit import ViTFeatureExtractor
else:
from .utils.dummy_vision_objects import *
@@ -3302,6 +3318,16 @@ if TYPE_CHECKING:
UniSpeechSatModel,
UniSpeechSatPreTrainedModel,
)
+ from .models.vilt import (
+ VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ViltForImageAndTextRetrieval,
+ ViltForImagesAndTextClassification,
+ ViltForMaskedLM,
+ ViltForQuestionAnswering,
+ ViltLayer,
+ ViltModel,
+ ViltPreTrainedModel,
+ )
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
from .models.vision_text_dual_encoder import VisionTextDualEncoderModel
from .models.visual_bert import (
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index c7ad786d33..fde42fc078 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -104,6 +104,7 @@ from . import (
trocr,
unispeech,
unispeech_sat,
+ vilt,
vision_encoder_decoder,
vision_text_dual_encoder,
visual_bert,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 12e9159915..f5de25919a 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -30,6 +30,7 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
+ ("vilt", "ViltConfig"),
("vit_mae", "ViTMAEConfig"),
("realm", "RealmConfig"),
("nystromformer", "NystromformerConfig"),
@@ -119,6 +120,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
+ ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -196,6 +198,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
+ ("vilt", "ViLT"),
("vit_mae", "ViTMAE"),
("realm", "Realm"),
("nystromformer", "Nystromformer"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index d73cdb7145..c76f9b84d4 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
+ ("vilt", "ViltModel"),
("vit_mae", "ViTMAEModel"),
("nystromformer", "NystromformerModel"),
("imagegpt", "ImageGPTModel"),
diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py
index 77d610f720..a5da154f99 100644
--- a/src/transformers/models/deit/modeling_deit.py
+++ b/src/transformers/models/deit/modeling_deit.py
@@ -297,12 +297,6 @@ class DeiTLayer(nn.Module):
# in DeiT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
-
- # TODO feedforward chunking not working for now
- # layer_output = apply_chunking_to_forward(
- # self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layer_output
- # )
-
layer_output = self.intermediate(layer_output)
# second residual connection is done here
@@ -312,11 +306,6 @@ class DeiTLayer(nn.Module):
return outputs
- def feed_forward_chunk(self, attention_output):
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.output(intermediate_output)
- return layer_output
-
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
class DeiTEncoder(nn.Module):
diff --git a/src/transformers/models/vilt/__init__.py b/src/transformers/models/vilt/__init__.py
new file mode 100644
index 0000000000..09e5b59c56
--- /dev/null
+++ b/src/transformers/models/vilt/__init__.py
@@ -0,0 +1,68 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+# rely on isort to merge the imports
+from ...file_utils import _LazyModule, is_torch_available, is_vision_available
+
+
+_import_structure = {
+ "configuration_vilt": ["VILT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViltConfig"],
+}
+
+if is_vision_available():
+ _import_structure["feature_extraction_vilt"] = ["ViltFeatureExtractor"]
+ _import_structure["processing_vilt"] = ["ViltProcessor"]
+
+if is_torch_available():
+ _import_structure["modeling_vilt"] = [
+ "VILT_PRETRAINED_MODEL_ARCHIVE_LIST",
+ "ViltForImageAndTextRetrieval",
+ "ViltForImagesAndTextClassification",
+ "ViltForMaskedLM",
+ "ViltForQuestionAnswering",
+ "ViltLayer",
+ "ViltModel",
+ "ViltPreTrainedModel",
+ ]
+
+
+if TYPE_CHECKING:
+ from .configuration_vilt import VILT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViltConfig
+
+ if is_vision_available():
+ from .feature_extraction_vilt import ViltFeatureExtractor
+ from .processing_vilt import ViltProcessor
+
+ if is_torch_available():
+ from .modeling_vilt import (
+ VILT_PRETRAINED_MODEL_ARCHIVE_LIST,
+ ViltForImageAndTextRetrieval,
+ ViltForImagesAndTextClassification,
+ ViltForMaskedLM,
+ ViltForQuestionAnswering,
+ ViltLayer,
+ ViltModel,
+ ViltPreTrainedModel,
+ )
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/vilt/configuration_vilt.py b/src/transformers/models/vilt/configuration_vilt.py
new file mode 100644
index 0000000000..5e1c40df5b
--- /dev/null
+++ b/src/transformers/models/vilt/configuration_vilt.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" VilT model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VILT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ # TODO
+}
+
+
+class ViltConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`ViltModel`]. It is used to instantiate a ViLT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the ViLT
+    [dandelin/vilt-b32-mlm](https://huggingface.co/dandelin/vilt-b32-mlm) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the text part of the model. Defines the number of different tokens that can be
+ represented by the `inputs_ids` passed when calling [`ViltModel`].
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`ViltModel`]. This is used when encoding
+ text.
+ modality_type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the modalities passed when calling [`ViltModel`]. This is used after concatenating the
+ embeddings of the text and image modalities.
+ max_position_embeddings (`int`, *optional*, defaults to 40):
+ The maximum sequence length that this model might ever be used with.
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 384):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 32):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ max_image_length (`int`, *optional*, defaults to -1):
+ The maximum number of patches to take as input for the Transformer encoder. If set to a positive integer,
+ the encoder will sample `max_image_length` patches at maximum. If set to -1, will not be taken into
+ account.
+ num_images (`int`, *optional*, defaults to -1):
+ The number of images to use for natural language visual reasoning. If set to a positive integer, will be
+ used by [`ViltForImagesAndTextClassification`] for defining the classifier head.
+
+ Example:
+
+ ```python
+    >>> from transformers import ViltModel, ViltConfig
+
+    >>> # Initializing a ViLT dandelin/vilt-b32-mlm style configuration
+    >>> configuration = ViltConfig()
+
+    >>> # Initializing a model from the dandelin/vilt-b32-mlm style configuration
+    >>> model = ViltModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "vilt"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ type_vocab_size=2,
+ modality_type_vocab_size=2,
+ max_position_embeddings=40,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ is_encoder_decoder=False,
+ image_size=384,
+ patch_size=32,
+ num_channels=3,
+ qkv_bias=True,
+ max_image_length=-1,
+ tie_word_embeddings=False,
+ num_images=-1,
+ **kwargs
+ ):
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.type_vocab_size = type_vocab_size
+ self.modality_type_vocab_size = modality_type_vocab_size
+ self.max_position_embeddings = max_position_embeddings
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+ self.max_image_length = max_image_length
+ self.num_images = num_images
diff --git a/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
new file mode 100644
index 0000000000..5b68e330d2
--- /dev/null
+++ b/src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
@@ -0,0 +1,297 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ViLT checkpoints from the original Github repository."""
+
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from PIL import Image
+
+import requests
+from huggingface_hub import cached_download, hf_hub_url
+from transformers import (
+ BertTokenizer,
+ ViltConfig,
+ ViltFeatureExtractor,
+ ViltForImageAndTextRetrieval,
+ ViltForImagesAndTextClassification,
+ ViltForMaskedLM,
+ ViltForQuestionAnswering,
+ ViltProcessor,
+)
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys(config, vqa_model=False, nlvr_model=False, irtr_model=False):
+ rename_keys = []
+ for i in range(config.num_hidden_layers):
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
+ rename_keys.append((f"transformer.blocks.{i}.norm1.weight", f"vilt.encoder.layer.{i}.layernorm_before.weight"))
+ rename_keys.append((f"transformer.blocks.{i}.norm1.bias", f"vilt.encoder.layer.{i}.layernorm_before.bias"))
+ rename_keys.append(
+ (f"transformer.blocks.{i}.attn.proj.weight", f"vilt.encoder.layer.{i}.attention.output.dense.weight")
+ )
+ rename_keys.append(
+ (f"transformer.blocks.{i}.attn.proj.bias", f"vilt.encoder.layer.{i}.attention.output.dense.bias")
+ )
+ rename_keys.append((f"transformer.blocks.{i}.norm2.weight", f"vilt.encoder.layer.{i}.layernorm_after.weight"))
+ rename_keys.append((f"transformer.blocks.{i}.norm2.bias", f"vilt.encoder.layer.{i}.layernorm_after.bias"))
+ rename_keys.append(
+ (f"transformer.blocks.{i}.mlp.fc1.weight", f"vilt.encoder.layer.{i}.intermediate.dense.weight")
+ )
+ rename_keys.append((f"transformer.blocks.{i}.mlp.fc1.bias", f"vilt.encoder.layer.{i}.intermediate.dense.bias"))
+ rename_keys.append((f"transformer.blocks.{i}.mlp.fc2.weight", f"vilt.encoder.layer.{i}.output.dense.weight"))
+ rename_keys.append((f"transformer.blocks.{i}.mlp.fc2.bias", f"vilt.encoder.layer.{i}.output.dense.bias"))
+
+ # embeddings
+ rename_keys.extend(
+ [
+ # text embeddings
+ ("text_embeddings.word_embeddings.weight", "vilt.embeddings.text_embeddings.word_embeddings.weight"),
+ (
+ "text_embeddings.position_embeddings.weight",
+ "vilt.embeddings.text_embeddings.position_embeddings.weight",
+ ),
+ ("text_embeddings.position_ids", "vilt.embeddings.text_embeddings.position_ids"),
+ (
+ "text_embeddings.token_type_embeddings.weight",
+ "vilt.embeddings.text_embeddings.token_type_embeddings.weight",
+ ),
+ ("text_embeddings.LayerNorm.weight", "vilt.embeddings.text_embeddings.LayerNorm.weight"),
+ ("text_embeddings.LayerNorm.bias", "vilt.embeddings.text_embeddings.LayerNorm.bias"),
+ # patch embeddings
+ ("transformer.cls_token", "vilt.embeddings.cls_token"),
+ ("transformer.patch_embed.proj.weight", "vilt.embeddings.patch_embeddings.projection.weight"),
+ ("transformer.patch_embed.proj.bias", "vilt.embeddings.patch_embeddings.projection.bias"),
+ ("transformer.pos_embed", "vilt.embeddings.position_embeddings"),
+ # token type embeddings
+ ("token_type_embeddings.weight", "vilt.embeddings.token_type_embeddings.weight"),
+ ]
+ )
+
+ # final layernorm + pooler
+ rename_keys.extend(
+ [
+ ("transformer.norm.weight", "vilt.layernorm.weight"),
+ ("transformer.norm.bias", "vilt.layernorm.bias"),
+ ("pooler.dense.weight", "vilt.pooler.dense.weight"),
+ ("pooler.dense.bias", "vilt.pooler.dense.bias"),
+ ]
+ )
+
+ # classifier head(s)
+ if vqa_model:
+ # classification head
+ rename_keys.extend(
+ [
+ ("vqa_classifier.0.weight", "classifier.0.weight"),
+ ("vqa_classifier.0.bias", "classifier.0.bias"),
+ ("vqa_classifier.1.weight", "classifier.1.weight"),
+ ("vqa_classifier.1.bias", "classifier.1.bias"),
+ ("vqa_classifier.3.weight", "classifier.3.weight"),
+ ("vqa_classifier.3.bias", "classifier.3.bias"),
+ ]
+ )
+ elif nlvr_model:
+ # classification head
+ rename_keys.extend(
+ [
+ ("nlvr2_classifier.0.weight", "classifier.0.weight"),
+ ("nlvr2_classifier.0.bias", "classifier.0.bias"),
+ ("nlvr2_classifier.1.weight", "classifier.1.weight"),
+ ("nlvr2_classifier.1.bias", "classifier.1.bias"),
+ ("nlvr2_classifier.3.weight", "classifier.3.weight"),
+ ("nlvr2_classifier.3.bias", "classifier.3.bias"),
+ ]
+ )
+ else:
+ pass
+
+ return rename_keys
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config):
+ for i in range(config.num_hidden_layers):
+ prefix = "vilt."
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"transformer.blocks.{i}.attn.qkv.weight")
+ in_proj_bias = state_dict.pop(f"transformer.blocks.{i}.attn.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
+ : config.hidden_size, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+ config.hidden_size : config.hidden_size * 2, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+ config.hidden_size : config.hidden_size * 2
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
+ -config.hidden_size :, :
+ ]
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+def remove_classification_head_(state_dict):
+ ignore_keys = ["head.weight", "head.bias"]
+ for k in ignore_keys:
+ state_dict.pop(k, None)
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+@torch.no_grad()
+def convert_vilt_checkpoint(checkpoint_url, pytorch_dump_folder_path):
+ """
+ Copy/paste/tweak model's weights to our ViLT structure.
+ """
+
+ # define configuration and initialize HuggingFace model
+ config = ViltConfig(image_size=384, patch_size=32, tie_word_embeddings=False)
+ mlm_model = False
+ vqa_model = False
+ nlvr_model = False
+ irtr_model = False
+ if "vqa" in checkpoint_url:
+ vqa_model = True
+ config.num_labels = 3129
+ repo_id = "datasets/huggingface/label-files"
+ filename = "vqa2-id2label.json"
+ id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+ config.id2label = id2label
+ config.label2id = {v: k for k, v in id2label.items()}
+ model = ViltForQuestionAnswering(config)
+ elif "nlvr" in checkpoint_url:
+ nlvr_model = True
+ config.num_labels = 2
+ config.id2label = {0: "False", 1: "True"}
+ config.label2id = {v: k for k, v in config.id2label.items()}
+ config.modality_type_vocab_size = 3
+ model = ViltForImagesAndTextClassification(config)
+ elif "irtr" in checkpoint_url:
+ irtr_model = True
+ model = ViltForImageAndTextRetrieval(config)
+ elif "mlm_itm" in checkpoint_url:
+ mlm_model = True
+ model = ViltForMaskedLM(config)
+ else:
+ raise ValueError("Unknown model type")
+
+ # load state_dict of original model, remove and rename some keys
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"]
+ rename_keys = create_rename_keys(config, vqa_model, nlvr_model, irtr_model)
+ for src, dest in rename_keys:
+ rename_key(state_dict, src, dest)
+ read_in_q_k_v(state_dict, config)
+ if mlm_model or irtr_model:
+ ignore_keys = ["itm_score.fc.weight", "itm_score.fc.bias"]
+ for k in ignore_keys:
+ state_dict.pop(k, None)
+
+ # load state dict into HuggingFace model
+ model.eval()
+ if mlm_model:
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ assert missing_keys == ["mlm_score.decoder.bias"]
+ else:
+ model.load_state_dict(state_dict)
+
+ # Define processor
+ feature_extractor = ViltFeatureExtractor(size=384)
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ processor = ViltProcessor(feature_extractor, tokenizer)
+
+ # Forward pass on example inputs (image + text)
+ if nlvr_model:
+ image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
+ image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
+ text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ encoding_1 = processor(image1, text, return_tensors="pt")
+ encoding_2 = processor(image2, text, return_tensors="pt")
+ outputs = model(
+ input_ids=encoding_1.input_ids,
+ pixel_values=encoding_1.pixel_values,
+ pixel_values_2=encoding_2.pixel_values,
+ )
+ else:
+ image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+ if mlm_model:
+ text = "a bunch of [MASK] laying on a [MASK]."
+ else:
+ text = "How many cats are there?"
+ encoding = processor(image, text, return_tensors="pt")
+ outputs = model(**encoding)
+
+ # Verify outputs
+ if mlm_model:
+ expected_shape = torch.Size([1, 11, 30522])
+ expected_slice = torch.tensor([-12.5061, -12.5123, -12.5174])
+ assert outputs.logits.shape == expected_shape
+ assert torch.allclose(outputs.logits[0, 0, :3], expected_slice, atol=1e-4)
+
+ # verify masked token prediction equals "cats"
+ predicted_id = outputs.logits[0, 4, :].argmax(-1).item()
+ assert tokenizer.decode([predicted_id]) == "cats"
+ elif vqa_model:
+ expected_shape = torch.Size([1, 3129])
+ expected_slice = torch.tensor([-15.9495, -18.1472, -10.3041])
+        assert outputs.logits.shape == expected_shape
+        assert torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)
+
+ # verify vqa prediction equals "2"
+ predicted_idx = outputs.logits.argmax(-1).item()
+ assert model.config.id2label[predicted_idx] == "2"
+ elif nlvr_model:
+ expected_shape = torch.Size([1, 2])
+ expected_slice = torch.tensor([-2.8721, 2.1291])
+ assert torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)
+ assert outputs.logits.shape == expected_shape
+
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+ print(f"Saving model and processor to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+ processor.save_pretrained(pytorch_dump_folder_path)
+
+
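+# Illustrative invocation (the output directory name is just an example):
+#
+#   python convert_vilt_original_to_pytorch.py \
+#       --checkpoint_url https://github.com/dandelin/ViLT/releases/download/200k/vilt_200k_mlm_itm.ckpt \
+#       --pytorch_dump_folder_path ./vilt-b32-mlm-itm
+#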
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://github.com/dandelin/ViLT/releases/download/200k/vilt_200k_mlm_itm.ckpt",
+ type=str,
+ help="URL of the checkpoint you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+
+ args = parser.parse_args()
+ convert_vilt_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
diff --git a/src/transformers/models/vilt/feature_extraction_vilt.py b/src/transformers/models/vilt/feature_extraction_vilt.py
new file mode 100644
index 0000000000..344bf98270
--- /dev/null
+++ b/src/transformers/models/vilt/feature_extraction_vilt.py
@@ -0,0 +1,292 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for ViLT."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...file_utils import TensorType, is_torch_available
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageFeatureExtractionMixin,
+ ImageInput,
+ is_torch_tensor,
+)
+from ...utils import logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class ViltFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
+ r"""
+ Constructs a ViLT feature extractor.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the input based on `size`.
+ size (`int`, *optional*, defaults to 384):
+ Resize the shorter side of the input to the given size. Should be an integer. The longer side will be
+ limited to under int((1333 / 800) * size) while preserving the aspect ratio. Only has an effect if
+ `do_resize` is set to `True`.
+ size_divisor (`int`, *optional*, defaults to 32):
+ The size by which to make sure both the height and width can be divided.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+ `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+ if `do_resize` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether or not to normalize the input with mean and standard deviation.
+        image_mean (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            The sequence of means for each channel, to be used when normalizing images.
+        image_std (`List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
+ def __init__(
+ self,
+ do_resize=True,
+ size=384,
+ size_divisor=32,
+ resample=Image.BICUBIC,
+ do_normalize=True,
+ image_mean=None,
+ image_std=None,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.resample = resample
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+ def _resize(self, image, shorter=800, longer=1333, size_divisor=32, resample=Image.BICUBIC):
+ """
+ Resizes the shorter edge of `image` to `shorter` and limits the longer edge to under `longer`, while preserving
+ the aspect ratio. Also makes sure that both the height and width can be divided by `size_divisor`.
+
+ Based on original implementation:
+ https://github.com/dandelin/ViLT/blob/3db8b5035464afee84d951bf6322e1b27f1d072d/vilt/transforms/utils.py#L5
+
+ Args:
+ image (`PIL.Image`):
+ The image to resize.
+ shorter (`int`, *optional*, defaults to `800`):
+ The size to which to resize the shorter side of the image.
+ longer (`int`, *optional*, defaults to `1333`):
+ The size by which to limit the longer side of the image, while preserving the aspect ratio.
+ size_divisor (`int`, *optional*, defaults to `32`):
+ The size by which both the height and the width must be divisible.
+ resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
+ An optional resampling filter.
+                An optional resampling filter.
+        """
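+        # Illustrative walk-through of the resize rules implemented below (numbers chosen for clarity):
+        # for an 800x600 image (width x height) with shorter=384, longer=639 and size_divisor=32, the
+        # scale is 384 / 600 = 0.64, so the image first becomes 512x384; 512 does not exceed 639 and
+        # both sides are already divisible by 32, so the final size is width 512, height 384.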
+ if not isinstance(image, Image.Image):
+ image = self.to_pil_image(image)
+
+ w, h = image.size
+ min_size = shorter
+ max_size = longer
+ scale = min_size / min(w, h)
+ if h < w:
+ newh, neww = min_size, scale * w
+ else:
+ newh, neww = scale * h, min_size
+
+ if max(newh, neww) > max_size:
+ scale = max_size / max(newh, neww)
+ newh = newh * scale
+ neww = neww * scale
+
+ newh, neww = int(newh + 0.5), int(neww + 0.5)
+ newh, neww = newh // size_divisor * size_divisor, neww // size_divisor * size_divisor
+
+ return self.resize(image, size=(neww, newh), resample=resample)
+
+ def _max_by_axis(self, the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+ def pad_and_create_pixel_mask(
+ self, pixel_values_list: List["torch.Tensor"], return_tensors: Optional[Union[str, TensorType]] = None
+ ):
+ """
+ Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
+
+ Args:
+ pixel_values_list (`List[torch.Tensor]`):
+ List of images (pixel values) to be padded. Each image should be a tensor of shape (C, H, W).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
+ objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model.
+ - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
+ *"pixel_mask"* is in `self.model_input_names`).
+ """
+
+ max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
+ c, h, w = max_size
+ padded_images = []
+ pixel_mask = []
+ for image in pixel_values_list:
+ # create padded image
+ padded_image = np.zeros((c, h, w), dtype=np.float32)
+ padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+ padded_images.append(padded_image)
+ # create pixel mask
+ mask = np.zeros((h, w), dtype=np.int64)
+ mask[: image.shape[1], : image.shape[2]] = True
+ pixel_mask.append(mask)
+
+ # return as BatchFeature
+ data = {"pixel_values": padded_images, "pixel_mask": pixel_mask}
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
+
+ def __call__(
+ self,
+ images: ImageInput,
+ pad_and_return_pixel_mask: Optional[bool] = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several image(s).
+
+
+
+        NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
+        PIL images.
+
+
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
+ number of channels, H and W are image height and width.
+
+ pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether or not to pad images up to the largest image in a batch and create a pixel mask.
+
+ If left to the default, will return a pixel mask that is:
+
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+ - **pixel_mask** -- Pixel mask to be fed to a model (when `return_pixel_mask=True` or if *"pixel_mask"* is
+ in `self.model_input_names`).
+ """
+ # Input type checking for clearer error
+ valid_images = False
+
+ # Check that images has a valid type
+ if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
+ valid_images = True
+ elif isinstance(images, (list, tuple)):
+ if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
+ valid_images = True
+
+ if not valid_images:
+ raise ValueError(
+                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
+ "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+ )
+
+ is_batched = bool(
+ isinstance(images, (list, tuple))
+ and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
+ )
+
+ if not is_batched:
+ images = [images]
+
+ # transformations (resizing + normalization)
+ if self.do_resize and self.size is not None:
+ longer = int((1333 / 800) * self.size)
+ images = [
+ self._resize(
+ image=image,
+ shorter=self.size,
+ longer=longer,
+ size_divisor=self.size_divisor,
+ resample=self.resample,
+ )
+ for image in images
+ ]
+ if self.do_normalize:
+ images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
+
+ if pad_and_return_pixel_mask:
+ # pad images up to largest image in batch and create pixel_mask
+ max_size = self._max_by_axis([list(image.shape) for image in images])
+ c, h, w = max_size
+ padded_images = []
+ pixel_mask = []
+ for image in images:
+ # create padded image
+ padded_image = np.zeros((c, h, w), dtype=np.float32)
+ padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+ padded_images.append(padded_image)
+ # create pixel mask
+ mask = np.zeros((h, w), dtype=np.int64)
+ mask[: image.shape[1], : image.shape[2]] = True
+ pixel_mask.append(mask)
+ images = padded_images
+
+ # return as BatchFeature
+ data = {}
+ data["pixel_values"] = images
+ if pad_and_return_pixel_mask:
+ data["pixel_mask"] = pixel_mask
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ return encoded_inputs
diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py
new file mode 100755
index 0000000000..28343fe791
--- /dev/null
+++ b/src/transformers/models/vilt/modeling_vilt.py
@@ -0,0 +1,1366 @@
+# coding=utf-8
+# Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch ViLT model."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from packaging import version
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ MaskedLMOutput,
+ ModelOutput,
+ SequenceClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import logging
+from .configuration_vilt import ViltConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "ViltConfig"
+_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm-itm"
+
+VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "dandelin/vilt-b32-mlm-itm",
+ # See all ViLT models at https://huggingface.co/models?filter=vilt
+]
+
+
+@dataclass
+class ViltForImagesAndTextClassificationOutput(ModelOutput):
+ """
+ Class for outputs of [`ViltForImagesAndTextClassification`].
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
+ hidden_states (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the output of
+ the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the attention
+            weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`). Attention weights after
+            the attention softmax, used to compute the weighted average in the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ hidden_states: Optional[List[Tuple[torch.FloatTensor]]] = None
+ attentions: Optional[List[Tuple[torch.FloatTensor]]] = None
+
+
+# Copied from transformers.models.vit.modeling_vit.to_2tuple
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+class ViltEmbeddings(nn.Module):
+ """
+ Construct the text and patch embeddings.
+
+ Text embeddings are equivalent to BERT embeddings.
+
+ Patch embeddings are equivalent to ViT embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ # text embeddings
+ self.text_embeddings = TextEmbeddings(config)
+ # patch embeddings
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+ self.patch_embeddings = PatchEmbeddings(
+ image_size=config.image_size,
+ patch_size=config.patch_size,
+ num_channels=config.num_channels,
+ embed_dim=config.hidden_size,
+ )
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
+ # modality type (text/patch) embeddings
+ self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
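+        # pixel_values: (batch_size, num_channels, height, width); pixel_mask: (batch_size, height, width).
+        # Returns the flattened patch embeddings (with interpolated position encodings), the flattened patch mask,
+        # and the patch index metadata as (patch_index, (height, width)).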
+ _, _, ph, pw = self.patch_embeddings.projection.weight.shape
+
+ x = self.patch_embeddings(pixel_values)
+ x_mask = pixel_mask[:, None, :, :].float()
+ x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
+ x_h = x_mask[:, 0].sum(dim=1)[:, 0]
+ x_w = x_mask[:, 0].sum(dim=2)[:, 0]
+
+ batch_size, num_channels, height, width = x.shape
+ patch_dim = self.config.image_size // self.config.patch_size
+ spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
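+        # interpolate the pre-trained position embeddings to each image's effective (non-padded) patch grid and
+        # zero-pad them back to the full grid, so they line up with the flattened patches below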
+ pos_embed = torch.cat(
+ [
+ nn.functional.pad(
+ nn.functional.interpolate(
+ spatial_pos,
+ size=(h, w),
+ mode="bilinear",
+ align_corners=True,
+ ),
+ (0, width - w, 0, height - h),
+ )
+ for h, w in zip(x_h, x_w)
+ ],
+ dim=0,
+ )
+
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ x = x.flatten(2).transpose(1, 2)
+ patch_index = torch.stack(
+ torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
+ )
+ patch_index = patch_index[None, None, :, :, :]
+ patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
+ patch_index = patch_index.flatten(1, 3)
+ x_mask = x_mask.flatten(1)
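+        # patches, position embeddings, patch indices and the mask are now flat sequences of length height * width
+        # per image; the next step samples at most `max_image_length` of these positions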
+
+        if max_image_length is None or not isinstance(max_image_length, int) or max_image_length < 0:
+            # suppose the augmented image is 800 x 1333; the maximum effective resolution is then 800 x 1333
+            # (if one side gets bigger, the other is constrained and shrunk accordingly).
+            # (800 // self.patch_size) * (1333 // self.patch_size) is the maximum number of patches a single image can get.
+            # if self.patch_size = 32, 25 * 41 = 1025
+            # if the resolution is 384 x 640, 12 * 20 = 240
+ effective_resolution = x_h * x_w
+ max_image_length = effective_resolution.max()
+ else:
+ effective_resolution = x_h * x_w
+ max_image_length = min(effective_resolution.max(), max_image_length)
+
+ valid_idx = x_mask.nonzero(as_tuple=False)
+ non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
+ unique_rows = valid_idx[:, 0].unique()
+ valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
+ non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]
+
+ valid_nums = [v.size(0) for v in valid_row_idx]
+ non_valid_nums = [v.size(0) for v in non_valid_row_idx]
+ pad_nums = [max_image_length - v for v in valid_nums]
+
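+        # keep exactly `max_image_length` patch positions per image: subsample the valid (non-padded) patches when
+        # there are too many, otherwise top up with randomly repeated padding positions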
+ select = list()
+ for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
+ if p <= 0:
+ valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
+ select.append(valid_row_idx[i][valid_choice])
+ else:
+ pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
+ select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))
+
+ select = torch.cat(select, dim=0)
+ x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
+ x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
+ patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
+ pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
+
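+        # prepend the [CLS] token and its position embedding, then add the position embeddings and apply dropout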
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ pos_embed = torch.cat(
+ (self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
+ )
+ x = x + pos_embed
+ x = self.dropout(x)
+
+ x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)
+
+ return x, x_mask, (patch_index, (height, width))
+
+ def forward(
+ self,
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ pixel_values,
+ pixel_mask,
+ inputs_embeds,
+ image_embeds,
+ image_token_type_idx=1,
+ ):
+ # PART 1: text embeddings
+ text_embeds = self.text_embeddings(
+ input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+
+ # PART 2: patch embeddings (with interpolated position encodings)
+ if image_embeds is None:
+ image_embeds, image_masks, patch_index = self.visual_embed(
+ pixel_values, pixel_mask, max_image_length=self.config.max_image_length
+ )
+ else:
+ image_masks = pixel_mask.flatten(1)
+
+ # PART 3: add modality type embeddings
+ # 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
+ if image_token_type_idx is None:
+ image_token_type_idx = 1
+ text_embeds = text_embeds + self.token_type_embeddings(
+ torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
+ )
+ image_embeds = image_embeds + self.token_type_embeddings(
+ torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
+ )
+
+ # PART 4: concatenate
+ embeddings = torch.cat([text_embeds, image_embeds], dim=1)
+ masks = torch.cat([attention_mask, image_masks], dim=1)
+
+ return embeddings, masks
+
+
+class TextEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
+ self.register_buffer(
+ "token_type_ids",
+ torch.zeros(self.position_ids.size(), dtype=torch.long),
+ persistent=False,
+ )
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, :seq_length]
+
+        # When token_type_ids are not passed (which usually happens when they are auto-generated), fall back to the
+        # all-zeros buffer registered in the constructor. The registered buffer helps users trace the model without
+        # passing token_type_ids (see issue #5664).
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+# Based on timm implementation, which can be found here:
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+class PatchEmbeddings(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
+ super().__init__()
+ image_size = to_2tuple(image_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values):
+ batch_size, num_channels, height, width = pixel_values.shape
+ x = self.projection(pixel_values)
+ return x
+
+
+class ViltSelfAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vilt
+class ViltSelfOutput(nn.Module):
+ """
+ The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class ViltAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = ViltSelfAttention(config)
+ self.output = ViltSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+ self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
+class ViltIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
+class ViltOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+class ViltLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = ViltAttention(config)
+ self.intermediate = ViltIntermediate(config)
+ self.output = ViltOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in ViLT, layernorm is applied before self-attention
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in ViLT, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class ViltEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ViltPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ViltConfig
+ base_model_prefix = "vilt"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, ViltEncoder):
+ module.gradient_checkpointing = value
+
+
+VILT_START_DOCSTRING = r"""
+    This model is a PyTorch [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+    usage and behavior.
+
+ Parameters:
+ config ([`ViltConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VILT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`ViltFeatureExtractor`]. See
+ [`ViltFeatureExtractor.__call__`] for details.
+
+ pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+ Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
+
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+            [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`BertTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+ IDs?](../glossary#input-ids)
+
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ [What are token type IDs?](../glossary#token-type-ids)
+
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_images, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`ViltFeatureExtractor`]. See
+ [`ViltFeatureExtractor.__call__`] for details.
+
+ pixel_mask (`torch.LongTensor` of shape `(batch_size, num_images, height, width)`, *optional*):
+ Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
+
+ - 1 for pixels that are real (i.e. **not masked**),
+ - 0 for pixels that are padding (i.e. **masked**).
+            [What are attention masks?](../glossary#attention-mask)
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
+ Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ViLT Model transformer outputting raw hidden-states without any specific head on top.",
+ VILT_START_DOCSTRING,
+)
+class ViltModel(ViltPreTrainedModel):
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = ViltEmbeddings(config)
+ self.encoder = ViltEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = ViltPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.text_embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.text_embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ image_token_type_idx=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+        >>> from transformers import ViltProcessor, ViltModel
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> # prepare image and text
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> text = "hello world"
+
+        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm-itm")
+        >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")
+
+        >>> inputs = processor(image, text, return_tensors="pt")
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ batch_size, num_channels, height, width = pixel_values.shape
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output, attention_mask = self.embeddings(
+ input_ids,
+ attention_mask,
+ token_type_ids,
+ pixel_values,
+ pixel_mask,
+ inputs_embeds,
+ image_embeds,
+ image_token_type_idx=image_token_type_idx,
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class ViltPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@add_start_docstrings(
+ """
+ ViLT Model with a language modeling head on top as done during pretraining.
+ """,
+ VILT_START_DOCSTRING,
+)
+class ViltForMaskedLM(ViltPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.vilt = ViltModel(config)
+ self.mlm_score = ViltMLMHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.mlm_score.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.mlm_score.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked),
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ViltProcessor, ViltForMaskedLM
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> text = "How many cats are there?"
+
+ >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
+ >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
+
+ >>> # prepare inputs
+ >>> encoding = processor(image, text, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**encoding)
+ >>> logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vilt(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output, pooled_output = outputs[:2]
+ # split up final hidden states into text and image features
+ text_seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+ text_features, _ = (sequence_output[:, :text_seq_len], sequence_output[:, text_seq_len:])
+
+ mlm_logits = self.mlm_score(text_features)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (mlm_logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=mlm_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class ViltPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+class ViltMLMHead(nn.Module):
+ def __init__(self, config, weight=None):
+ super().__init__()
+ self.config = config
+ self.transform = ViltPredictionHeadTransform(config)
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ if weight is not None:
+ self.decoder.weight = weight
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, x):
+ x = self.transform(x)
+ x = self.decoder(x)
+ return x
+
+
+@add_start_docstrings(
+ """
+ Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
+ token) for visual question answering, e.g. for VQAv2.
+ """,
+ VILT_START_DOCSTRING,
+)
+class ViltForQuestionAnswering(ViltPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.vilt = ViltModel(config)
+
+ # Classifier head
+ self.classifier = nn.Sequential(
+ nn.Linear(config.hidden_size, config.hidden_size * 2),
+ nn.LayerNorm(config.hidden_size * 2),
+ nn.GELU(),
+ nn.Linear(config.hidden_size * 2, config.num_labels),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
+ Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
+ all answers that are applicable for a given example in the batch, or a soft encoding indicating which
+ answers are applicable, where 1.0 is the highest score.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ViltProcessor, ViltForQuestionAnswering
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> text = "How many cats are there?"
+
+ >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+ >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
+
+ >>> # prepare inputs
+ >>> encoding = processor(image, text, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(**encoding)
+ >>> logits = outputs.logits
+ >>> idx = logits.argmax(-1).item()
+ >>> print("Predicted answer:", model.config.id2label[idx])
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vilt(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooler_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.classifier(pooler_output)
+
+ loss = None
+ if labels is not None:
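+            # VQAv2 is cast as multi-label classification with soft targets, hence binary cross-entropy; scaling by
+            # the number of answer classes follows the reference implementation linked below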
+ loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
+ # see https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
+ token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
+ """,
+ VILT_START_DOCSTRING,
+)
+class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.vilt = ViltModel(config)
+
+ # Classifier head
+ self.rank_output = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+    @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels are currently not supported.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+ >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
+
+ >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
+ >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
+
+ >>> # forward pass
+ >>> scores = dict()
+ >>> for text in texts:
+ ... encoding = processor(image, text, return_tensors="pt")
+ ... outputs = model(**encoding)
+ ... scores[text] = outputs.logits[0, :].item()
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.vilt(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values,
+ pixel_mask=pixel_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooler_output = outputs.pooler_output if return_dict else outputs[1]
+
+ logits = self.rank_output(pooler_output)
+
+ loss = None
+ if labels is not None:
+ raise NotImplementedError("Training is not yet supported.")
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
+ """,
+    VILT_START_DOCSTRING,
+)
+class ViltForImagesAndTextClassification(ViltPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.vilt = ViltModel(config)
+
+ # Classifier head
+ num_images = config.num_images
+ self.classifier = nn.Sequential(
+ nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
+ nn.LayerNorm(config.hidden_size * num_images),
+ nn.GELU(),
+ nn.Linear(config.hidden_size * num_images, config.num_labels),
+ )
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+    @add_start_docstrings_to_model_forward(
+        VILT_IMAGES_AND_TEXT_CLASSIFICATION_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+    )
+ @replace_return_docstrings(output_type=ViltForImagesAndTextClassificationOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ pixel_values=None,
+ pixel_mask=None,
+ head_mask=None,
+ inputs_embeds=None,
+ image_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Binary classification labels.
+
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import ViltProcessor, ViltForImagesAndTextClassification
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
+ >>> image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg", stream=True).raw)
+ >>> text = "The left image contains twice the number of dogs as the right image."
+
+ >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
+ >>> model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
+
+ >>> # prepare inputs
+ >>> encoding = processor([image1, image2], text, return_tensors="pt")
+
+ >>> # forward pass
+ >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
+ >>> logits = outputs.logits
+ >>> idx = logits.argmax(-1).item()
+ >>> print("Predicted answer:", model.config.id2label[idx])
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values.ndim == 4:
+ # add dummy num_images dimension
+ pixel_values = pixel_values.unsqueeze(1)
+
+ num_images = pixel_values.shape[1]
+ if num_images != self.config.num_images:
+ raise ValueError(
+ "Make sure to match the number of images in the model with the number of images in the input."
+ )
+ pooler_outputs = []
+ hidden_states = [] if output_hidden_states else None
+ attentions = [] if output_attentions else None
+ for i in range(num_images):
+ # forward every image through the model
+ outputs = self.vilt(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ pixel_values=pixel_values[:, i, :, :, :],
+ pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ image_embeds=image_embeds,
+ image_token_type_idx=i + 1,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooler_output = outputs.pooler_output if return_dict else outputs[1]
+ pooler_outputs.append(pooler_output)
+ if output_hidden_states:
+ hidden_states.append(outputs.hidden_states)
+ if output_attentions:
+ attentions.append(outputs.attentions)
+
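+        # concatenate the pooled [CLS] outputs of all image passes along the hidden dimension before classifying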
+ pooled_output = torch.cat(pooler_outputs, dim=-1)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits, hidden_states, attentions)
+ return ((loss,) + output) if loss is not None else output
+
+ return ViltForImagesAndTextClassificationOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=hidden_states,
+ attentions=attentions,
+ )
diff --git a/src/transformers/models/vilt/processing_vilt.py b/src/transformers/models/vilt/processing_vilt.py
new file mode 100644
index 0000000000..50ca918065
--- /dev/null
+++ b/src/transformers/models/vilt/processing_vilt.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for ViLT.
+"""
+
+from typing import List, Optional, Union
+
+from transformers import BertTokenizerFast
+
+from ...file_utils import TensorType
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from .feature_extraction_vilt import ViltFeatureExtractor
+
+
+class ViltProcessor:
+ r"""
+ Constructs a ViLT processor which wraps a BERT tokenizer and ViLT feature extractor into a single processor.
+
+ [`ViltProcessor`] offers all the functionalities of [`ViltFeatureExtractor`] and [`BertTokenizerFast`]. See the
+ docstring of [`~ViltProcessor.__call__`] and [`~ViltProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`ViltFeatureExtractor`):
+ An instance of [`ViltFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`BertTokenizerFast`):
+            An instance of [`BertTokenizerFast`]. The tokenizer is a required input.
+ """
+
+ def __init__(self, feature_extractor, tokenizer):
+ if not isinstance(feature_extractor, ViltFeatureExtractor):
+ raise ValueError(
+                f"`feature_extractor` has to be of type {ViltFeatureExtractor}, but is {type(feature_extractor)}"
+ )
+ if not isinstance(tokenizer, BertTokenizerFast):
+            raise ValueError(f"`tokenizer` has to be of type {BertTokenizerFast}, but is {type(tokenizer)}")
+
+ self.feature_extractor = feature_extractor
+ self.tokenizer = tokenizer
+ self.current_processor = self.feature_extractor
+
+ def save_pretrained(self, save_directory):
+ """
+ Save a ViLT feature_extractor object and BERT tokenizer object to the directory `save_directory`, so that it
+ can be re-loaded using the [`~ViltProcessor.from_pretrained`] class method.
+
+
+
+ This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
+ [`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
+ above for more information.
+
+
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
+ be created if it does not exist).
+ """
+
+ self.feature_extractor.save_pretrained(save_directory)
+ self.tokenizer.save_pretrained(save_directory)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ r"""
+ Instantiate a [`ViltProcessor`] from a pretrained ViLT processor.
+
+
+
+ This class method is simply calling ViltFeatureExtractor's
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and BertTokenizerFast's
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`]. Please refer to the docstrings of the methods
+ above for more information.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
+ namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a feature extractor file saved using the
+ [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ **kwargs
+ Additional keyword arguments passed along to both [`SequenceFeatureExtractor`] and
+ [`PreTrainedTokenizer`]
+ """
+ feature_extractor = ViltFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+ return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ def __call__(
+ self,
+ images,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ add_special_tokens: bool = True,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = False,
+ max_length: Optional[int] = None,
+ stride: int = 0,
+ pad_to_multiple_of: Optional[int] = None,
+ return_token_type_ids: Optional[bool] = None,
+ return_attention_mask: Optional[bool] = None,
+ return_overflowing_tokens: bool = False,
+ return_special_tokens_mask: bool = False,
+ return_offsets_mapping: bool = False,
+ return_length: bool = False,
+ verbose: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ **kwargs
+ ) -> BatchEncoding:
+ """
+ This method uses [`ViltFeatureExtractor.__call__`] method to prepare image(s) for the model, and
+ [`BertTokenizerFast.__call__`] to prepare text for the model.
+
+ Please refer to the docstring of the above two methods for more information.
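+
+        Example (a minimal sketch; the checkpoint name and the COCO image follow the examples in the modeling
+        docstrings and are assumptions, not values verified here):
+
+        ```python
+        >>> from transformers import ViltProcessor
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm-itm")
+        >>> encoding = processor(image, "a photo of two cats", return_tensors="pt")
+        ```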
+ """
+ encoding = self.tokenizer(
+ text=text,
+ add_special_tokens=add_special_tokens,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ stride=stride,
+ pad_to_multiple_of=pad_to_multiple_of,
+ return_token_type_ids=return_token_type_ids,
+ return_attention_mask=return_attention_mask,
+ return_overflowing_tokens=return_overflowing_tokens,
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_offsets_mapping=return_offsets_mapping,
+ return_length=return_length,
+ verbose=verbose,
+ return_tensors=return_tensors,
+ **kwargs,
+ )
+ # add pixel_values + pixel_mask
+ encoding_feature_extractor = self.feature_extractor(images, return_tensors=return_tensors)
+ encoding.update(encoding_feature_extractor)
+
+ return encoding
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 27530ef8a0..8856c34a92 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -326,12 +326,6 @@ class ViTLayer(nn.Module):
# in ViT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(hidden_states)
-
- # TODO feedforward chunking not working for now
- # layer_output = apply_chunking_to_forward(
- # self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layer_output
- # )
-
layer_output = self.intermediate(layer_output)
# second residual connection is done here
@@ -341,11 +335,6 @@ class ViTLayer(nn.Module):
return outputs
- def feed_forward_chunk(self, attention_output):
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.output(intermediate_output)
- return layer_output
-
class ViTEncoder(nn.Module):
def __init__(self, config):
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 04e398fd4c..01f3c0bc62 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3540,6 +3540,58 @@ class UniSpeechSatPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
+VILT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class ViltForImageAndTextRetrieval(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltForImagesAndTextClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltForMaskedLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltForQuestionAnswering(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltLayer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class ViltPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class VisionEncoderDecoderModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index b7408bf3db..9e5efabc06 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -87,6 +87,20 @@ class SegformerFeatureExtractor(metaclass=DummyObject):
requires_backends(self, ["vision"])
+class ViltFeatureExtractor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
+class ViltProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class ViTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/test_feature_extraction_vilt.py b/tests/test_feature_extraction_vilt.py
new file mode 100644
index 0000000000..350c1758ff
--- /dev/null
+++ b/tests/test_feature_extraction_vilt.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Copyright 2021 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.testing_utils import require_torch, require_vision
+
+from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ import torch
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ViltFeatureExtractor
+
+
+class ViltFeatureExtractionTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ do_resize=True,
+ size=30,
+ size_divisor=2,
+ do_normalize=True,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.do_resize = do_resize
+ self.size = size
+ self.size_divisor = size_divisor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+
+ def prepare_feat_extract_dict(self):
+ return {
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_normalize": self.do_normalize,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "size_divisor": self.size_divisor,
+ }
+
+ def get_expected_values(self, image_inputs, batched=False):
+ """
+ This function computes the expected height and width when providing images to ViltFeatureExtractor,
+ assuming do_resize is set to True with a scalar size and size_divisor.
+ """
+ if not batched:
+ image = image_inputs[0]
+ if isinstance(image, Image.Image):
+ w, h = image.size
+ else:
+ h, w = image.shape[1], image.shape[2]
+ scale = self.size / min(w, h)
+ if h < w:
+ newh, neww = self.size, scale * w
+ else:
+ newh, neww = scale * h, self.size
+
+ max_size = int((1333 / 800) * self.size)
+ if max(newh, neww) > max_size:
+ scale = max_size / max(newh, neww)
+ newh = newh * scale
+ neww = neww * scale
+
+ newh, neww = int(newh + 0.5), int(neww + 0.5)
+ expected_height, expected_width = (
+ newh // self.size_divisor * self.size_divisor,
+ neww // self.size_divisor * self.size_divisor,
+ )
+
+ else:
+ expected_values = []
+ for image in image_inputs:
+ expected_height, expected_width = self.get_expected_values([image])
+ expected_values.append((expected_height, expected_width))
+ expected_height = max(expected_values, key=lambda item: item[0])[0]
+ expected_width = max(expected_values, key=lambda item: item[1])[1]
+
+ return expected_height, expected_width
+
+
+@require_torch
+@require_vision
+class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
+
+ feature_extraction_class = ViltFeatureExtractor if is_vision_available() else None
+
+ def setUp(self):
+ self.feature_extract_tester = ViltFeatureExtractionTester(self)
+
+ @property
+ def feat_extract_dict(self):
+ return self.feature_extract_tester.prepare_feat_extract_dict()
+
+ def test_feat_extract_properties(self):
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ self.assertTrue(hasattr(feature_extractor, "image_mean"))
+ self.assertTrue(hasattr(feature_extractor, "image_std"))
+ self.assertTrue(hasattr(feature_extractor, "do_normalize"))
+ self.assertTrue(hasattr(feature_extractor, "do_resize"))
+ self.assertTrue(hasattr(feature_extractor, "size"))
+ self.assertTrue(hasattr(feature_extractor, "size_divisor"))
+
+ def test_batch_feature(self):
+ pass
+
+ def test_call_pil(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PIL images
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
+ for image in image_inputs:
+ self.assertIsInstance(image, Image.Image)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_numpy(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random numpy tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, np.ndarray)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_call_pytorch(self):
+ # Initialize feature_extractor
+ feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test not batched input
+ encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
+ self.assertEqual(
+ encoded_images.shape,
+ (1, self.feature_extract_tester.num_channels, expected_height, expected_width),
+ )
+
+ # Test batched
+ encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
+
+ expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
+ self.assertEqual(
+ encoded_images.shape,
+ (
+ self.feature_extract_tester.batch_size,
+ self.feature_extract_tester.num_channels,
+ expected_height,
+ expected_width,
+ ),
+ )
+
+ def test_equivalence_pad_and_create_pixel_mask(self):
+ # Initialize feature_extractors
+ feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
+ feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+ # create random PyTorch tensors
+ image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
+ for image in image_inputs:
+ self.assertIsInstance(image, torch.Tensor)
+
+ # Test whether the method "pad_and_create_pixel_mask" and calling the feature extractor return the same tensors
+ encoded_images_with_method = feature_extractor_1.pad_and_create_pixel_mask(image_inputs, return_tensors="pt")
+ encoded_images = feature_extractor_2(image_inputs, return_tensors="pt")
+
+ self.assertTrue(
+ torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
+ )
+ self.assertTrue(
+ torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
+ )
diff --git a/tests/test_modeling_vilt.py b/tests/test_modeling_vilt.py
new file mode 100644
index 0000000000..e9eca63adc
--- /dev/null
+++ b/tests/test_modeling_vilt.py
@@ -0,0 +1,607 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch ViLT model. """
+
+import unittest
+
+from datasets import load_dataset
+
+from transformers import ViltConfig, is_torch_available, is_vision_available
+from transformers.file_utils import cached_property
+from transformers.models.auto import get_values
+from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ MODEL_MAPPING,
+ ViltForImageAndTextRetrieval,
+ ViltForImagesAndTextClassification,
+ ViltForMaskedLM,
+ ViltForQuestionAnswering,
+ ViltModel,
+ )
+ from transformers.models.vilt.modeling_vilt import VILT_PRETRAINED_MODEL_ARCHIVE_LIST
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ViltProcessor
+
+
+class ViltModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ seq_length=7,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_input_mask=True,
+ use_token_type_ids=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ scope=None,
+ modality_type_vocab_size=2,
+ add_multiple_images=False,
+ num_images=-1,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_token_type_ids = use_token_type_ids
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.scope = scope
+ self.modality_type_vocab_size = modality_type_vocab_size
+ self.add_multiple_images = add_multiple_images
+ self.num_images = num_images
+ # we set the expected sequence length (which is used in several tests)
+ # this is equal to the seq length of the text tokens + number of image patches + 1 for the CLS token
+ self.expected_seq_len = self.seq_length + (self.image_size // self.patch_size) ** 2 + 1
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+ if self.add_multiple_images:
+ pixel_values = floats_tensor([self.batch_size, 2, self.num_channels, self.image_size, self.image_size])
+ else:
+ pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ if self.use_labels:
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+
+ config = self.get_config()
+
+ return (config, input_ids, token_type_ids, input_mask, pixel_values, token_labels)
+
+ def get_config(self):
+ return ViltConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ num_labels=self.num_labels,
+ modality_type_vocab_size=self.modality_type_vocab_size,
+ num_images=self.num_images,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ pixel_values,
+ token_labels,
+ ):
+ model = ViltModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, pixel_values=pixel_values)
+ result = model(input_ids, token_type_ids=token_type_ids, pixel_values=pixel_values)
+ result = model(input_ids, pixel_values=pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape, (self.batch_size, self.expected_seq_len, self.hidden_size)
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ pixel_values,
+ token_labels,
+ ) = config_and_inputs
+ inputs_dict = {
+ "input_ids": input_ids,
+ "token_type_ids": token_type_ids,
+ "attention_mask": input_mask,
+ "pixel_values": pixel_values,
+ }
+ return config, inputs_dict
+
+ def prepare_pixel_values(self):
+ return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
+
+
+@require_torch
+class ViltModelTest(ModelTesterMixin, unittest.TestCase):
+
+ all_model_classes = (
+ (
+ ViltModel,
+ ViltForQuestionAnswering,
+ ViltForImageAndTextRetrieval,
+ ViltForMaskedLM,
+ )
+ if is_torch_available()
+ else ()
+ )
+ test_pruning = False
+ test_headmasking = False
+ test_torchscript = False
+
+ # ViltForMaskedLM, ViltForQuestionAnswering and ViltForImagesAndTextClassification require special treatment
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+ inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+
+ # if model_class.__name__ == "ViltForNaturalLanguageVisualReasoning":
+ # inputs_dict["pixel_values"] = floats_tensor([self.model_tester.batch_size, self.model_tester.num_images, self.model_tester.num_channels, self.model_tester.image_size, self.model_tester.image_size])
+
+ if return_labels:
+ if model_class.__name__ == "ViltForQuestionAnswering":
+ inputs_dict["labels"] = torch.zeros(
+ self.model_tester.batch_size, self.model_tester.num_labels, device=torch_device
+ )
+ elif model_class.__name__ == "ViltForMaskedLM":
+ inputs_dict["labels"] = torch.zeros(
+ (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
+ )
+ elif model_class.__name__ == "ViltForImagesAndTextClassification":
+ inputs_dict["labels"] = torch.zeros(
+ self.model_tester.batch_size, dtype=torch.long, device=torch_device
+ )
+
+ return inputs_dict
+
+ def setUp(self):
+ self.model_tester = ViltModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=ViltConfig, hidden_size=37)
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_training(self):
+ if not self.model_tester.is_training:
+ return
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ config.modality_type_vocab_size = 3
+
+ # ViltForImageAndTextRetrieval doesn't support training for now
+ if model_class in [*get_values(MODEL_MAPPING), ViltForImageAndTextRetrieval]:
+ continue
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ def test_training_gradient_checkpointing(self):
+ if not self.model_tester.is_training:
+ return
+
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.use_cache = False
+ config.return_dict = True
+
+ # ViltForImageAndTextRetrieval doesn't support training for now
+ if (
+ model_class in [*get_values(MODEL_MAPPING), ViltForImageAndTextRetrieval]
+ or not model_class.supports_gradient_checkpointing
+ ):
+ continue
+
+ model = model_class(config)
+ model.to(torch_device)
+ model.gradient_checkpointing_enable()
+ model.train()
+ inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+ loss = model(**inputs).loss
+ loss.backward()
+
+ @unittest.skip(
+ reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
+ hidden states"""
+ )
+ def test_save_load(self):
+ pass
+
+ @unittest.skip(
+ reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
+ hidden states"""
+ )
+ def test_determinism(self):
+ pass
+
+ @unittest.skip(
+ reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
+ hidden states"""
+ )
+ def test_model_outputs_equivalence(self):
+ pass
+
+ def test_attention_outputs(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "expected_seq_len", None)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ # attentions are a list of length num_images
+ # each element contains the attentions of a particular image index
+ self.assertEqual(len(attentions), self.model_tester.num_images)
+ self.assertEqual(len(attentions[0]), self.model_tester.num_hidden_layers)
+ else:
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ # attentions are a list of length num_images
+ # each element contains the attentions of a particular image index
+ self.assertEqual(len(attentions), self.model_tester.num_images)
+ self.assertEqual(len(attentions[0]), self.model_tester.num_hidden_layers)
+ else:
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ self.assertListEqual(
+ list(attentions[0][0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+ else:
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ self.assertEqual(out_len + 1, len(outputs))
+
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ self.assertEqual(len(self_attentions), self.model_tester.num_images)
+ self.assertEqual(len(self_attentions[0]), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0][0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+ else:
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, seq_len, seq_len],
+ )
+
+ def test_hidden_states_output(self):
+ def check_hidden_states_output(inputs_dict, config, model_class):
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+
+ expected_num_layers = getattr(
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+ )
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ # hidden_states are a list of length num_images
+ # each element contains the hidden states of a particular image index
+ self.assertEqual(len(hidden_states), self.model_tester.num_images)
+ self.assertEqual(len(hidden_states[0]), expected_num_layers)
+ else:
+ self.assertEqual(len(hidden_states), expected_num_layers)
+
+ seq_length = self.model_tester.expected_seq_len
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ self.assertListEqual(
+ list(hidden_states[0][0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+ else:
+ self.assertListEqual(
+ list(hidden_states[0].shape[-2:]),
+ [seq_length, self.model_tester.hidden_size],
+ )
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ print("Model class:", model_class)
+ inputs_dict["output_hidden_states"] = True
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ # check that output_hidden_states also work using config
+ del inputs_dict["output_hidden_states"]
+ config.output_hidden_states = True
+
+ check_hidden_states_output(inputs_dict, config, model_class)
+
+ def test_retain_grad_hidden_states_attentions(self):
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.output_hidden_states = True
+ config.output_attentions = True
+
+ # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0]
+ model = model_class(config)
+ model.to(torch_device)
+
+ inputs = self._prepare_for_class(inputs_dict, model_class)
+
+ outputs = model(**inputs)
+
+ output = outputs[0]
+
+ # Encoder-/Decoder-only models
+ hidden_states = outputs.hidden_states[0]
+ attentions = outputs.attentions[0]
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ # hidden_states are a list of length num_images
+ # each element contains the hidden states of a particular image index
+ hidden_states[0].retain_grad()
+ attentions[0].retain_grad()
+ else:
+ hidden_states.retain_grad()
+ attentions.retain_grad()
+
+ output.flatten()[0].backward(retain_graph=True)
+
+ if model_class.__name__ == "ViltForImagesAndTextClassification":
+ # hidden_states are a list of length num_images
+ # each element contains the hidden states of a particular image index
+ self.assertIsNotNone(hidden_states[0].grad)
+ self.assertIsNotNone(attentions[0].grad)
+ else:
+ self.assertIsNotNone(hidden_states.grad)
+ self.assertIsNotNone(attentions.grad)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in VILT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
+ model = ViltModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+@require_torch
+class ViltForImagesAndTextClassificationModelTest(ViltModelTest, unittest.TestCase):
+
+ all_model_classes = (ViltForImagesAndTextClassification,) if is_torch_available() else ()
+
+ def setUp(self):
+ self.model_tester = ViltModelTester(self, modality_type_vocab_size=3, add_multiple_images=True, num_images=2)
+ self.config_tester = ConfigTester(self, config_class=ViltConfig, hidden_size=37)
+
+ @unittest.skip("We only test the model that takes in multiple images")
+ def test_model(self):
+ pass
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class ViltModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_processor(self):
+ return ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") if is_vision_available() else None
+
+ @slow
+ def test_inference_masked_lm(self):
+ model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm").to(torch_device)
+
+ processor = self.default_processor
+ image = prepare_img()
+ text = "a bunch of [MASK] laying on a [MASK]."
+ inputs = processor(image, text, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size([1, 11, 30522])
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-12.5061, -12.5123, -12.5174]).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, 0, :3], expected_slice, atol=1e-4))
+
+ # verify masked token prediction equals "cats"
+ predicted_id = outputs.logits[0, 4, :].argmax(-1).item()
+ assert processor.decode([predicted_id]) == "cats"
+
+ @slow
+ def test_inference_visual_question_answering(self):
+ model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa").to(torch_device)
+
+ processor = self.default_processor
+ image = prepare_img()
+ text = "How many cats are there?"
+ inputs = processor(image, text, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 3129))
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-15.9495, -18.1472, -10.3041]).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+ # compute loss
+ vqa_labels = [[2, 3, 155, 800]]
+ vqa_scores = [[1.0, 0.3, 0.3, 0.3]]
+ labels = torch.zeros(1, model.config.num_labels).to(torch_device)
+
+ for i, (labels_example, scores_example) in enumerate(zip(vqa_labels, vqa_scores)):
+ for l, s in zip(labels_example, scores_example):
+ labels[i, l] = s
+
+ # forward pass
+ outputs = model(**inputs, labels=labels)
+
+ # verify we have a positive loss
+ self.assertTrue(outputs.loss > 0)
+
+ @slow
+ def test_inference_natural_language_visual_reasoning(self):
+ model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2").to(
+ torch_device
+ )
+
+ processor = self.default_processor
+
+ dataset = load_dataset("hf-internal-testing/fixtures_nlvr2", split="test")
+ image1 = Image.open(dataset[0]["file"]).convert("RGB")
+ image2 = Image.open(dataset[1]["file"]).convert("RGB")
+
+ text = "The left image contains twice the number of dogs as the right image, and at least two dogs in total are standing."
+ encoding_1 = processor(image1, text, return_tensors="pt")
+ encoding_2 = processor(image2, text, return_tensors="pt")
+
+ pixel_values = torch.stack([encoding_1.pixel_values, encoding_2.pixel_values], dim=1)
+
+ # forward pass
+ outputs = model(
+ input_ids=encoding_1.input_ids,
+ pixel_values=pixel_values,
+ )
+
+ # verify the logits
+ expected_shape = torch.Size([1, 2])
+ self.assertEqual(outputs.logits.shape, expected_shape)
+
+ expected_slice = torch.tensor([-2.4013, 2.9342]).to(torch_device)
+ self.assertTrue(torch.allclose(outputs.logits[0, :2], expected_slice, atol=1e-4))
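Finally, the `expected_seq_len` that `ViltModelTester` uses for the attention and hidden-state shape checks above is plain arithmetic; with the tester defaults it evaluates as follows (a worked example, not part of the patch).

```python
# text tokens + image patches + 1 for the CLS token (ViltModelTester defaults)
seq_length, image_size, patch_size = 7, 30, 2

num_patches = (image_size // patch_size) ** 2  # 15 ** 2 = 225
expected_seq_len = seq_length + num_patches + 1
print(expected_seq_len)  # 233
```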
diff --git a/utils/check_repo.py b/utils/check_repo.py
index a1ecb75f42..5ecdfeaea6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -108,6 +108,10 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
+ "ViltForQuestionAnswering",
+ "ViltForImagesAndTextClassification",
+ "ViltForImageAndTextRetrieval",
+ "ViltForMaskedLM",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",