Add Document Image Transformer (DiT) (#15984)
* Add conversion script * Improve script * Fix bug * Add option to push to hub * Add support for classification models * Update model name * Upload feature extractor files first * Remove hash checking * Fix config * Add id2label * Add import * Fix id2label file name * Fix expected shape * Add model to README * Improve docs * Add integration test and fix CI * Fix code style * Add missing init * Add model to SPECIAL_MODULE_TO_TEST_MAP Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
6c9010ef63
commit
0835119bf3
|
@ -252,6 +252,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
|
|||
1. **[Data2Vec](https://huggingface.co/docs/transformers/master/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/master/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DeiT](https://huggingface.co/docs/transformers/model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
|
|
|
@ -237,6 +237,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
|||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/master/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/master/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
|
|
|
@ -261,6 +261,7 @@ conda install -c huggingface transformers
|
|||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (来自 Facebook) 伴随论文 [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) 由 Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko 发布。
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (来自 Microsoft Research) 伴随论文 [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) 由 Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan 发布。
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (来自 HuggingFace), 伴随论文 [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 同样的方法也应用于压缩 GPT-2 到 [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation), RoBERTa 到 [DistilRoBERTa](https://github.com/huggingface/transformers/tree/master/examples/distillation), Multilingual BERT 到 [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) 和德语版 DistilBERT。
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/master/model_doc/dit)** (来自 Microsoft Research) 伴随论文 [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) 由 Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei 发布。
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (来自 Facebook) 伴随论文 [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) 由 Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih 发布。
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (来自 Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning 发布。
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (来自 Google Research) 伴随论文 [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) 由 Sascha Rothe, Shashi Narayan, Aliaksei Severyn 发布。
|
||||
|
|
|
@ -273,6 +273,7 @@ conda install -c huggingface transformers
|
|||
1. **[DETR](https://huggingface.co/docs/transformers/model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](https://huggingface.co/docs/transformers/model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
1. **[DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/master/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[DiT](https://huggingface.co/docs/transformers/master/model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DPR](https://huggingface.co/docs/transformers/model_doc/dpr)** (from Facebook) released with the paper [Dense Passage Retrieval for Open-Domain Question Answering](https://arxiv.org/abs/2004.04906) by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[ELECTRA](https://huggingface.co/docs/transformers/model_doc/electra)** (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||
1. **[EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder-decoder)** (from Google Research) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
|
|
|
@ -194,6 +194,8 @@
|
|||
title: DialoGPT
|
||||
- local: model_doc/distilbert
|
||||
title: DistilBERT
|
||||
- local: model_doc/dit
|
||||
title: DiT
|
||||
- local: model_doc/dpr
|
||||
title: DPR
|
||||
- local: model_doc/electra
|
||||
|
|
|
@ -78,6 +78,7 @@ conversion utilities for the following models.
|
|||
1. **[Data2Vec](model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DeBERTa](model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DiT](model_doc/dit)** (from Microsoft Research) released with the paper [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
1. **[DeiT](model_doc/deit)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou.
|
||||
1. **[DETR](model_doc/detr)** (from Facebook) released with the paper [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko.
|
||||
1. **[DialoGPT](model_doc/dialogpt)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DiT
|
||||
|
||||
## Overview
|
||||
|
||||
DiT was proposed in [DiT: Self-supervised Pre-training for Document Image Transformer](https://arxiv.org/abs/2203.02378) by Junlong Li, Yiheng Xu, Tengchao Lv, Lei Cui, Cha Zhang, Furu Wei.
|
||||
DiT applies the self-supervised objective of [BEiT](beit) (BERT pre-training of Image Transformers) to 42 million document images, allowing for state-of-the-art results on tasks including:
|
||||
|
||||
- document image classification: the [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset (a collection of
|
||||
400,000 images belonging to one of 16 classes).
|
||||
- document layout analysis: the [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) dataset (a collection of more
|
||||
than 360,000 document images constructed by automatically parsing PubMed XML files).
|
||||
- table detection: the [ICDAR 2019 cTDaR](https://github.com/cndplab-founder/ICDAR2019_cTDaR) dataset (a collection of
|
||||
600 training images and 240 testing images).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Image Transformer has recently achieved significant progress for natural image understanding, either using supervised (ViT, DeiT, etc.) or self-supervised (BEiT, MAE, etc.) pre-training techniques. In this paper, we propose DiT, a self-supervised pre-trained Document Image Transformer model using large-scale unlabeled text images for Document AI tasks, which is essential since no supervised counterparts ever exist due to the lack of human labeled document images. We leverage DiT as the backbone network in a variety of vision-based Document AI tasks, including document image classification, document layout analysis, as well as table detection. Experiment results have illustrated that the self-supervised pre-trained DiT model achieves new state-of-the-art results on these downstream tasks, e.g. document image classification (91.11 → 92.69), document layout analysis (91.0 → 94.9) and table detection (94.23 → 96.55). *
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/dit_architecture.jpg"
|
||||
alt="drawing" width="600"/>
|
||||
|
||||
<small> Summary of the approach. Taken from the [original paper](https://arxiv.org/abs/2203.02378). </small>
|
||||
|
||||
One can directly use the weights of DiT with the AutoModel API:
|
||||
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained("microsoft/dit-base")
|
||||
```
|
||||
|
||||
This will load the model pre-trained on masked image modeling. Note that this won't include the language modeling head on top, used to predict visual tokens.
|
||||
|
||||
To include the head, you can load the weights into a `BeitForMaskedImageModeling` model, like so:
|
||||
|
||||
```python
|
||||
from transformers import BeitForMaskedImageModeling
|
||||
|
||||
model = BeitForMaskedImageModeling.from_pretrained("microsoft/dit-base")
|
||||
```
|
||||
|
||||
You can also load a fine-tuned model from the [hub](https://huggingface.co/models?other=dit), like so:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForImageClassification
|
||||
|
||||
model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
|
||||
```
|
||||
|
||||
This particular checkpoint was fine-tuned on [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/), an important benchmark for document image classification.
|
||||
A notebook that illustrates inference for document image classification can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DiT/Inference_with_DiT_(Document_Image_Transformer)_for_document_image_classification.ipynb).
|
||||
|
||||
As DiT's architecture is equivalent to that of BEiT, one can refer to [BEiT's documentation page](beit) for all tips, code examples and notebooks.
|
||||
|
||||
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/microsoft/unilm/tree/master/dit).
|
|
@ -211,6 +211,7 @@ _import_structure = {
|
|||
"models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"],
|
||||
"models.dialogpt": [],
|
||||
"models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"],
|
||||
"models.dit": [],
|
||||
"models.dpr": [
|
||||
"DPR_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"DPRConfig",
|
||||
|
|
|
@ -47,6 +47,7 @@ from . import (
|
|||
detr,
|
||||
dialogpt,
|
||||
distilbert,
|
||||
dit,
|
||||
dpr,
|
||||
electra,
|
||||
encoder_decoder,
|
||||
|
|
|
@ -330,6 +330,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||
("layoutxlm", "LayoutXLM"),
|
||||
("data2vec-audio", "Data2VecAudio"),
|
||||
("data2vec-text", "Data2VecText"),
|
||||
("dit", "DiT"),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert DiT checkpoints from the unilm repository."""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import requests
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
|
||||
rename_keys = []
|
||||
for i in range(config.num_hidden_layers):
|
||||
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
|
||||
)
|
||||
rename_keys.append(
|
||||
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
|
||||
)
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
|
||||
|
||||
# projection layer + position embeddings
|
||||
rename_keys.extend(
|
||||
[
|
||||
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
|
||||
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
|
||||
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
|
||||
(f"{prefix}pos_embed", "beit.embeddings.position_embeddings"),
|
||||
]
|
||||
)
|
||||
|
||||
if has_lm_head:
|
||||
# mask token + layernorm
|
||||
rename_keys.extend(
|
||||
[
|
||||
("mask_token", "beit.embeddings.mask_token"),
|
||||
("norm.weight", "layernorm.weight"),
|
||||
("norm.bias", "layernorm.bias"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# layernorm + classification head
|
||||
rename_keys.extend(
|
||||
[
|
||||
("fc_norm.weight", "beit.pooler.layernorm.weight"),
|
||||
("fc_norm.bias", "beit.pooler.layernorm.bias"),
|
||||
("head.weight", "classifier.weight"),
|
||||
("head.bias", "classifier.bias"),
|
||||
]
|
||||
)
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
|
||||
for i in range(config.num_hidden_layers):
|
||||
prefix = "backbone." if is_semantic else ""
|
||||
# queries, keys and values
|
||||
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
||||
: config.hidden_size, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
config.hidden_size : config.hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
||||
-config.hidden_size :, :
|
||||
]
|
||||
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
# gamma_1 and gamma_2
|
||||
# we call them lambda because otherwise they are renamed when using .from_pretrained
|
||||
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
|
||||
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
|
||||
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
|
||||
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_dit_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our BEiT structure.
|
||||
"""
|
||||
|
||||
# define default BEiT configuration
|
||||
has_lm_head = False if "rvlcdip" in checkpoint_url else True
|
||||
config = BeitConfig(use_absolute_position_embeddings=True, use_mask_token=has_lm_head)
|
||||
|
||||
# size of the architecture
|
||||
if "large" in checkpoint_url or "dit-l" in checkpoint_url:
|
||||
config.hidden_size = 1024
|
||||
config.intermediate_size = 4096
|
||||
config.num_hidden_layers = 24
|
||||
config.num_attention_heads = 16
|
||||
|
||||
# labels
|
||||
if "rvlcdip" in checkpoint_url:
|
||||
config.num_labels = 16
|
||||
repo_id = "datasets/huggingface/label-files"
|
||||
filename = "rvlcdip-id2label.json"
|
||||
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
|
||||
id2label = {int(k): v for k, v in id2label.items()}
|
||||
config.id2label = id2label
|
||||
config.label2id = {v: k for k, v in id2label.items()}
|
||||
|
||||
# load state_dict of original model, remove and rename some keys
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
|
||||
|
||||
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
|
||||
|
||||
# load HuggingFace model
|
||||
model = BeitForMaskedImageModeling(config) if has_lm_head else BeitForImageClassification(config)
|
||||
model.eval()
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# Check outputs on an image
|
||||
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
|
||||
image = prepare_img()
|
||||
|
||||
encoding = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = encoding["pixel_values"]
|
||||
|
||||
outputs = model(pixel_values)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify logits
|
||||
expected_shape = [1, 16] if "rvlcdip" in checkpoint_url else [1, 196, 8192]
|
||||
assert logits.shape == torch.Size(expected_shape), "Shape of logits not as expected"
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
print(f"Saving feature extractor to {pytorch_dump_folder_path}")
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
if has_lm_head:
|
||||
model_name = "dit-base" if "base" in checkpoint_url else "dit-large"
|
||||
else:
|
||||
model_name = "dit-base-finetuned-rvlcdip" if "dit-b" in checkpoint_url else "dit-large-finetuned-rvlcdip"
|
||||
feature_extractor.push_to_hub(
|
||||
repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
|
||||
organization="nielsr",
|
||||
commit_message="Add feature extractor",
|
||||
use_temp_dir=True,
|
||||
)
|
||||
model.push_to_hub(
|
||||
repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
|
||||
organization="nielsr",
|
||||
commit_message="Add model",
|
||||
use_temp_dir=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_dit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub)
|
|
@ -0,0 +1,61 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForImageClassification
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class DiTIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_for_image_classification(self):
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
|
||||
model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
|
||||
model.to(torch_device)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("nielsr/rvlcdip-demo")
|
||||
|
||||
image = dataset["train"][0]["image"].convert("RGB")
|
||||
|
||||
inputs = feature_extractor(image, return_tensors="pt")
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
|
||||
expected_shape = torch.Size((1, 16))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[-0.4158, -0.4092, -0.4347],
|
||||
device=torch_device,
|
||||
dtype=torch.float,
|
||||
)
|
||||
self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4))
|
|
@ -276,6 +276,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
|
|||
"auto/test_modeling_auto.py",
|
||||
"auto/test_modeling_tf_pytorch.py",
|
||||
"bort/test_modeling_bort.py",
|
||||
"dit/test_modeling_dit.py",
|
||||
],
|
||||
"models/auto/modeling_flax_auto.py": "auto/test_modeling_flax_auto.py",
|
||||
"models/auto/modeling_tf_auto.py": [
|
||||
|
|
Loading…
Reference in New Issue