add sdpa to ViT [follow up of #29325] (#30555)

remove blank line (+1 squashed commit)
Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8db3] [run-slow]vit_msn
[ead817eca] fix vit msn multi gpu
[d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa88f] doc
[a3ff33e4a] finish implementation
[e20b7b7fb] Update test_modeling_common.py
[e290c5810] Update test_modeling_flax_common.py
[d3af86f46] comment
[ff7dd32d8] more comments
[59b137889] suggestion
[7e2ba6d67] attn_implementation as attribute of the class
[fe66ab71f] minor
[38642b568] Apply suggestions from code review

Accept comments

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[22cde7d52] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[48e137cc6] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[99f4c679f] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00ebb0] all tests are passing locally
[e9e0b82b7] vision encoder/decoder
[4d5076b56] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e26] all tests are passing locally
[bff7fc366] minor
[62f88306f] fix yolo and text_encoder tests
[121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10dd] fix-copies
[ef6c511c4] test vit hybrid
[7d4ba8644] vit hybrid
[66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a031] fixes
[cfde6eb21] fixup
[e77df1ed3] all except yolo end encoder decoder (+17 squashed commits)
Squashed commits:
[602913e22] vit + vit_mae are working
[547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/  passes
[61a97dfa9] it s the complete opposite...
[aefab37d4] fix more tests
[71802a1b9] fix all torch tests
[40b12eb58] encoder - decoder tests
[941552b69] slow decorator where appropriate
[14d055d80] has_attentions to yolo and msn
[3381fa19f] add correct name
[e261316a7] repo consistency
[31c6d0c08] fixup
[9d214276c] minor fix
[11ed2e1b7] chore
[eca6644c4] add sdpa to vit-based models
[cffbf390b] make fix-copies result
[6468319b0] fix style
[d324cd02a] add sdpa for vit

Co-authored-by: Liubov Yaronskaya <luba.yaronskaya@gmail.com>
This commit is contained in:
hyenal 2024-05-16 10:56:11 +01:00 committed by GitHub
parent 9fd606dbdb
commit 1c21f48a50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 709 additions and 26 deletions

View File

@ -43,6 +43,34 @@ the authors compute the stats for a downstream dataset.
- Note that the AST needs a low learning rate (the authors use a 10 times smaller learning rate compared to their CNN model proposed in the
[PSLA paper](https://arxiv.org/abs/2102.01243)) and converges quickly, so please search for a suitable learning rate and learning rate scheduler for your task.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import ASTForAudioClassification
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MIT/ast-finetuned-audioset-10-10-0.4593` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 27 | 6 | 4.5 |
| 2 | 12 | 6 | 2 |
| 4 | 21 | 8 | 2.62 |
| 8 | 40 | 14 | 2.86 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with the Audio Spectrogram Transformer.

View File

@ -68,6 +68,34 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTImageProcessor`] in order to
prepare images for the model.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import DeiTForImageClassification
model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/deit-base-distilled-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 8 | 6 | 1.33 |
| 2 | 9 | 6 | 1.5 |
| 4 | 9 | 6 | 1.5 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DeiT.

View File

@ -33,6 +33,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/MCG-NJU/VideoMAE).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import VideoMAEForVideoClassification
model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MCG-NJU/videomae-base-finetuned-kinetics` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 37 | 10 | 3.7 |
| 2 | 24 | 18 | 1.33 |
| 4 | 43 | 32 | 1.34 |
| 8 | 84 | 60 | 1.4 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with VideoMAE. If

View File

@ -88,6 +88,34 @@ who already converted the weights from JAX to PyTorch. Credits go to him!
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
improvement of 2% to training from scratch, but still 4% behind supervised pre-training.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer).

View File

@ -39,6 +39,34 @@ substantially fewer computational resources to train.*
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import ViTHybridForImageClassification
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 29 | 18 | 1.61 |
| 2 | 26 | 18 | 1.44 |
| 4 | 25 | 18 | 1.39 |
| 8 | 34 | 24 | 1.42 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT Hybrid.

View File

@ -52,6 +52,34 @@ consists of Transformer blocks) takes as input. Each mask token is a shared, lea
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import ViTMAEModel
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-mae-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 11 | 6 | 1.83 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTMAE.

View File

@ -49,6 +49,34 @@ use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMS
- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K
labels when fine-tuned.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import ViTMSNForImageClassification
model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-msn-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT MSN.

View File

@ -32,6 +32,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
from transformers import AutoModelForObjectDetection
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `hustvl/yolos-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 106 | 76 | 1.39 |
| 2 | 154 | 90 | 1.71 |
| 4 | 222 | 116 | 1.91 |
| 8 | 368 | 168 | 2.19 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with YOLOS.

View File

@ -192,10 +192,12 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@ -216,12 +218,18 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
<Tip>

View File

@ -169,6 +169,38 @@ class ASTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
class ASTSdpaSelfAttention(ASTSelfAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
"""
@ -228,6 +260,13 @@ class ASTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
class ASTSdpaAttention(ASTAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention = ASTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig) -> None:
@ -261,7 +300,13 @@ class ASTOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST
AST_ATTENTION_CLASSES = {
"eager": ASTAttention,
"sdpa": ASTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -269,7 +314,7 @@ class ASTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ASTAttention(config)
self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ASTIntermediate(config)
self.output = ASTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -366,6 +411,7 @@ class ASTPreTrainedModel(PreTrainedModel):
base_model_prefix = "audio_spectrogram_transformer"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:

View File

@ -190,6 +190,38 @@ class DeiTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
class DeiTSdpaSelfAttention(DeiTSelfAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
class DeiTSelfOutput(nn.Module):
"""
@ -249,6 +281,13 @@ class DeiTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
class DeiTSdpaAttention(DeiTAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention = DeiTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
def __init__(self, config: DeiTConfig) -> None:
@ -282,7 +321,13 @@ class DeiTOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT
DEIT_ATTENTION_CLASSES = {
"eager": DeiTAttention,
"sdpa": DeiTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -290,7 +335,7 @@ class DeiTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = DeiTAttention(config)
self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = DeiTIntermediate(config)
self.output = DeiTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -388,6 +433,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["DeiTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

50
src/transformers/models/videomae/modeling_videomae.py Normal file → Executable file
View File

@ -134,7 +134,6 @@ class VideoMAEEmbeddings(nn.Module):
# add position embeddings
embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
# only keep visible patches
# ~bool_masked_pos means visible
if bool_masked_pos is not None:
@ -268,6 +267,40 @@ class VideoMAESelfAttention(nn.Module):
return outputs
class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
key_layer = self.transpose_for_scores(keys)
value_layer = self.transpose_for_scores(values)
query_layer = self.transpose_for_scores(queries)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
"""
@ -327,6 +360,13 @@ class VideoMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention = VideoMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
def __init__(self, config: VideoMAEConfig) -> None:
@ -360,7 +400,10 @@ class VideoMAEOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE
VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -368,7 +411,7 @@ class VideoMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VideoMAEAttention(config)
self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VideoMAEIntermediate(config)
self.output = VideoMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -465,6 +508,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "videomae"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""

View File

@ -336,8 +336,20 @@ class VisionEncoderDecoderModel(PreTrainedModel):
del tf_model
gc.collect()
attn_implementation = kwargs.get("attn_implementation", None)
kwargs_encoder_decoder = {}
if attn_implementation:
kwargs_encoder_decoder = {
"encoder_attn_implementation": attn_implementation,
"decoder_attn_implementation": attn_implementation,
}
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
encoder_dir,
decoder_dir,
encoder_from_tf=True,
decoder_from_tf=True,
**kwargs_encoder_decoder,
)
# This is only for copying some specific attributes of this particular model.
model.config = config

View File

@ -236,6 +236,37 @@ class ViTSelfAttention(nn.Module):
return outputs
class ViTSdpaSelfAttention(ViTSelfAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
@ -293,6 +324,12 @@ class ViTAttention(nn.Module):
return outputs
class ViTSdpaAttention(ViTAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention = ViTSdpaSelfAttention(config)
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
@ -324,6 +361,12 @@ class ViTOutput(nn.Module):
return hidden_states
VIT_ATTENTION_CLASSES = {
"eager": ViTAttention,
"sdpa": ViTSdpaAttention,
}
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -331,7 +374,7 @@ class ViTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTAttention(config)
self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -428,6 +471,7 @@ class ViTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@ -248,6 +248,38 @@ class ViTHybridSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid
class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid
class ViTHybridSelfOutput(nn.Module):
"""
@ -307,6 +339,13 @@ class ViTHybridAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid
class ViTHybridSdpaAttention(ViTHybridAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention = ViTHybridSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid
class ViTHybridIntermediate(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
@ -340,6 +379,12 @@ class ViTHybridOutput(nn.Module):
return hidden_states
VIT_HYBRID_ATTENTION_CLASSES = {
"eager": ViTHybridAttention,
"sdpa": ViTHybridSdpaAttention,
}
class ViTHybridLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -347,7 +392,7 @@ class ViTHybridLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTHybridAttention(config)
self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTHybridIntermediate(config)
self.output = ViTHybridOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -447,6 +492,7 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@ -241,8 +241,8 @@ class ViTMAEEmbeddings(nn.Module):
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
@ -370,6 +370,38 @@ class ViTMAESelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
"""
@ -429,6 +461,13 @@ class ViTMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
class ViTMAESdpaAttention(ViTMAEAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention = ViTMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
@ -462,7 +501,13 @@ class ViTMAEOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE
VITMAE_ATTENTION_CLASSES = {
"eager": ViTMAEAttention,
"sdpa": ViTMAESdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -470,7 +515,7 @@ class ViTMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTMAEAttention(config)
self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMAEIntermediate(config)
self.output = ViTMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -567,6 +612,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -764,7 +810,8 @@ class ViTMAEDecoder(nn.Module):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# unshuffle
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed

View File

@ -222,6 +222,38 @@ class ViTMSNSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN
class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
class ViTMSNSelfOutput(nn.Module):
"""
@ -281,6 +313,13 @@ class ViTMSNAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN
class ViTMSNSdpaAttention(ViTMSNAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention = ViTMSNSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
@ -314,7 +353,10 @@ class ViTMSNOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN
VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
class ViTMSNLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -322,7 +364,7 @@ class ViTMSNLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = ViTMSNAttention(config)
self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMSNIntermediate(config)
self.output = ViTMSNOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -419,7 +461,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTMSNAttention"]
_no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
_supports_sdpa = True
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
# when creating pre-training scripts.

View File

@ -307,6 +307,38 @@ class YolosSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos
class YolosSdpaSelfAttention(YolosSelfAttention):
def __init__(self, config: YolosConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
class YolosSelfOutput(nn.Module):
"""
@ -366,6 +398,13 @@ class YolosAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos
class YolosSdpaAttention(YolosAttention):
def __init__(self, config: YolosConfig) -> None:
super().__init__(config)
self.attention = YolosSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
class YolosIntermediate(nn.Module):
def __init__(self, config: YolosConfig) -> None:
@ -399,7 +438,10 @@ class YolosOutput(nn.Module):
return hidden_states
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
class YolosLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -407,7 +449,7 @@ class YolosLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = YolosAttention(config)
self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = YolosIntermediate(config)
self.output = YolosOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -531,6 +573,7 @@ class YolosPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@ -63,6 +63,7 @@ class ASTModelTester:
scope=None,
frequency_stride=2,
time_stride=2,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -83,6 +84,7 @@ class ASTModelTester:
self.scope = scope
self.frequency_stride = frequency_stride
self.time_stride = time_stride
self.attn_implementation = attn_implementation
# in AST, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
frequency_out_dimension = (self.num_mel_bins - self.patch_size) // self.frequency_stride + 1
@ -117,6 +119,7 @@ class ASTModelTester:
initializer_range=self.initializer_range,
frequency_stride=self.frequency_stride,
time_stride=self.time_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, input_values, labels):

View File

@ -80,6 +80,8 @@ class DeiTModelTester:
num_labels=3,
scope=None,
encoder_stride=2,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -99,10 +101,14 @@ class DeiTModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 2
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -130,6 +136,7 @@ class DeiTModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -71,6 +71,7 @@ class TFDeiTModelTester:
num_labels=3,
scope=None,
encoder_stride=2,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -90,6 +91,7 @@ class TFDeiTModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
num_patches = (image_size // patch_size) ** 2
@ -121,6 +123,7 @@ class TFDeiTModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -70,6 +70,7 @@ class VideoMAEModelTester:
initializer_range=0.02,
mask_ratio=0.9,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -91,6 +92,7 @@ class VideoMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame
self.num_patches_per_frame = (image_size // patch_size) ** 2
@ -132,6 +134,7 @@ class VideoMAEModelTester:
decoder_intermediate_size=self.intermediate_size,
decoder_num_attention_heads=self.num_attention_heads,
decoder_num_hidden_layers=self.num_hidden_layers,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):
@ -197,7 +200,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# hence we define a single mask, which we then repeat for each example in the batch
mask = torch.ones((self.model_tester.num_masks,))
mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
batch_size = inputs_dict["pixel_values"].shape[0]
bool_masked_pos = mask.expand(batch_size, -1).bool()
inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)
if return_labels:

View File

@ -492,7 +492,9 @@ class TFVisionEncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
tf_model.save_pretrained(tmpdirname, safe_serialization=False)
pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
pt_model = VisionEncoderDecoderModel.from_pretrained(
tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation
)
self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)

View File

@ -49,6 +49,7 @@ class FlaxViTModelTester(unittest.TestCase):
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -66,6 +67,7 @@ class FlaxViTModelTester(unittest.TestCase):
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@ -87,6 +89,7 @@ class FlaxViTModelTester(unittest.TestCase):
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
return config, pixel_values

View File

@ -63,6 +63,7 @@ class TFViTModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -81,6 +82,7 @@ class TFViTModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@ -111,6 +113,7 @@ class TFViTModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -68,6 +68,8 @@ class ViTModelTester:
initializer_range=0.02,
scope=None,
encoder_stride=2,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -87,10 +89,14 @@ class ViTModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.encoder_stride = encoder_stride
self.attn_implementation = attn_implementation
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -118,6 +124,7 @@ class ViTModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -58,6 +58,7 @@ class ViTHybridModelTester:
initializer_range=0.02,
backbone_featmap_shape=[1, 16, 4, 4],
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -77,6 +78,7 @@ class ViTHybridModelTester:
self.initializer_range = initializer_range
self.scope = scope
self.backbone_featmap_shape = backbone_featmap_shape
self.attn_implementation = attn_implementation
# in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
# the number of patches is based on the feature map of the backbone, which by default uses an output stride
@ -122,6 +124,7 @@ class ViTHybridModelTester:
backbone_featmap_shape=self.backbone_featmap_shape,
backbone_config=backbone_config,
backbone=None,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -72,6 +72,7 @@ class TFViTMAEModelTester:
num_labels=3,
mask_ratio=0.6,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -91,6 +92,7 @@ class TFViTMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
@ -127,6 +129,7 @@ class TFViTMAEModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
mask_ratio=self.mask_ratio,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -63,8 +63,9 @@ class ViTMAEModelTester:
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
mask_ratio=0.6,
scope=None,
mask_ratio=0.5,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -84,11 +85,15 @@ class ViTMAEModelTester:
self.initializer_range = initializer_range
self.mask_ratio = mask_ratio
self.scope = scope
self.attn_implementation = attn_implementation
# in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = num_patches
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@ -120,6 +125,7 @@ class ViTMAEModelTester:
decoder_intermediate_size=self.intermediate_size,
decoder_num_attention_heads=self.num_attention_heads,
decoder_num_hidden_layers=self.num_hidden_layers,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -59,6 +59,7 @@ class ViTMSNModelTester:
type_sequence_label_size=10,
initializer_range=0.02,
scope=None,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -77,6 +78,7 @@ class ViTMSNModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.attn_implementation = attn_implementation
# in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
@ -106,6 +108,7 @@ class ViTMSNModelTester:
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -62,6 +62,7 @@ class YolosModelTester:
scope=None,
n_targets=8,
num_detection_tokens=10,
attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@ -83,6 +84,7 @@ class YolosModelTester:
self.scope = scope
self.n_targets = n_targets
self.num_detection_tokens = num_detection_tokens
self.attn_implementation = attn_implementation
# we set the expected sequence length (which is used in several tests)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
@ -123,6 +125,7 @@ class YolosModelTester:
initializer_range=self.initializer_range,
num_detection_tokens=self.num_detection_tokens,
num_labels=self.num_labels,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels):

View File

@ -2788,7 +2788,9 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded = model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
@ -3724,6 +3726,11 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
@ -3861,6 +3868,27 @@ class ModelTesterMixin:
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if not deactivate_mask and (
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
):
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():

View File

@ -371,7 +371,9 @@ class FlaxModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded = pt_model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)

View File

@ -84,7 +84,7 @@ def check_sdpa_support_list():
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
if arch not in doctext and arch not in doctext.replace("-", "_"):
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)