[`BERT`] Add support for sdpa (#28802)

* Adding SDPA support for BERT

* Using the proper input name for testing model input in inference()

* Adding documentation for SDPA in BERT model page

* Use the stable link for the documentation

* Adding a gate to only call .contiguous() for torch < 2.2.0

* Additions and fixes to the documentation

* Minor updates to documentation

* Adding extra requirements needed for the contiguous() bug

* Adding "Adapted from" in plcae of the "Copied from"

* Add benchmark speedup tables to the documentation

* Minor fixes to the documentation

* Use ClapText as a replacement for Bert in the Copied-From

* Some more fixes for the fix-copies references

* Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage

[test all]

* Undo changes to separate test

* Refactored SDPA self attention code for KV projections

* Change use_sdpa to attn_implementation

* Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)
JB (Don) 2024-04-26 23:23:44 +08:00 committed by GitHub
parent 2de5cb12be
commit dfa7b580e9
26 changed files with 495 additions and 86 deletions

View File

@ -61,6 +61,53 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
- The model must predict the original sentence, but has a second objective: inputs are two sentences A and B (with a separation token in between). With probability 50%, the sentences are consecutive in the corpus, in the remaining 50% they are not related. The model has to predict if the sentences are consecutive or not.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the
following speedups during training and inference.
#### Training
|batch_size|seq_len|Time per batch (eager - s)|Time per batch (sdpa - s)|Speedup (%)|Eager peak mem (MB)|sdpa peak mem (MB)|Mem saving (%)|
|----------|-------|--------------------------|-------------------------|-----------|-------------------|------------------|--------------|
|4 |256 |0.023 |0.017 |35.472 |939.213 |764.834 |22.800 |
|4 |512 |0.023 |0.018 |23.687 |1970.447 |1227.162 |60.569 |
|8 |256 |0.023 |0.018 |23.491 |1594.295 |1226.114 |30.028 |
|8 |512 |0.035 |0.025 |43.058 |3629.401 |2134.262 |70.054 |
|16 |256 |0.030 |0.024 |25.583 |2874.426 |2134.262 |34.680 |
|16 |512 |0.064 |0.044 |46.223 |6964.659 |3961.013 |75.830 |
#### Inference
|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%)|Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
|----------|-------|----------------------------|---------------------------|-----------|--------------|-----------|-------------|
|1 |128 |5.736 |4.987 |15.022 |282.661 |282.924 |-0.093 |
|1 |256 |5.689 |4.945 |15.055 |298.686 |298.948 |-0.088 |
|2 |128 |6.154 |4.982 |23.521 |314.523 |314.785 |-0.083 |
|2 |256 |6.201 |4.949 |25.303 |347.546 |347.033 |0.148 |
|4 |128 |6.049 |4.987 |21.305 |378.895 |379.301 |-0.107 |
|4 |256 |6.285 |5.364 |17.166 |443.209 |444.382 |-0.264 |
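For readers who want to reproduce numbers of this kind, here is a rough sketch of how eager vs. SDPA inference latency can be compared. This is not the exact script used to produce the tables above; the checkpoint, batch size, sequence length, and timing loop are illustrative and assume a CUDA device.
```python
import time

import torch
from transformers import BertModel, BertTokenizer

device = "cuda"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer(
    ["hello"] * 4, padding="max_length", max_length=256, return_tensors="pt"
).to(device)


def benchmark(attn_implementation: str, steps: int = 50) -> float:
    model = BertModel.from_pretrained(
        "bert-base-uncased", torch_dtype=torch.float16, attn_implementation=attn_implementation
    ).to(device)
    model.eval()
    with torch.no_grad():
        for _ in range(5):  # warmup
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps


print(f"eager: {benchmark('eager') * 1000:.2f} ms/forward")
print(f"sdpa:  {benchmark('sdpa') * 1000:.2f} ms/forward")
```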
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

View File

@ -187,10 +187,11 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
## PyTorch scaled dot product attention
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available.
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
@ -224,6 +225,13 @@ FlashAttention can only be used for models with the `fp16` or `bf16` torch type,
</Tip>
<Tip>
SDPA does not support certain sets of attention parameters, such as `head_mask` and `output_attentions=True`.
In that case, you should see a warning message and we will fall back to the (slower) eager implementation.
</Tip>
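For example, if you need the attention weights returned, you can opt out of SDPA explicitly when loading the model (a minimal sketch; the checkpoint name is only illustrative):
```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, output_attentions=True)  # attention weights require the eager path
print(len(outputs.attentions))  # one attention tensor per layer
```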
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
```diff
@ -232,8 +240,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
# convert the model to BetterTransformer
model.to_bettertransformer()
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

View File

@ -445,10 +445,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if torch.all(mask == 1):
if is_tracing:
pass
elif tgt_len == 1:
if not is_tracing and torch.all(mask == 1):
if tgt_len == 1:
# For query_length == 1, causal attention and bi-directional attention are the same.
return None
elif key_value_length == tgt_len:

View File

@ -883,11 +883,18 @@ class AlignTextSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText
ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
"eager": AlignTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()

View File

@ -434,11 +434,18 @@ class AltRobertaSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta
ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": AltRobertaSelfAttention,
}
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
class AltRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = AltRobertaSelfOutput(config)
self.pruned_heads = set()
@ -1205,7 +1212,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
config_class = AltCLIPTextConfig
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->AltRoberta
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -1232,7 +1239,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -23,10 +23,15 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@ -45,6 +50,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
get_torch_version,
logging,
replace_return_docstrings,
)
@ -350,6 +356,103 @@ class BertSelfAttention(nn.Module):
return outputs
class BertSdpaSelfAttention(BertSelfAttention):
def __init__(self, config, position_embedding_type=None):
super().__init__(config, position_embedding_type=position_embedding_type)
self.dropout_prob = config.attention_probs_dropout_prob
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
# Adapted from BertSelfAttention
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
"BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
'`attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
bsz, tgt_len, _ = hidden_states.size()
query_layer = self.transpose_for_scores(self.query(hidden_states))
# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
# mask in case tgt_len == 1.
is_causal = self.is_decoder and attention_mask is None and tgt_len > 1
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
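For intuition, here is a self-contained sketch (independent of the classes in this diff) showing that `torch.nn.functional.scaled_dot_product_attention` matches the eager `softmax(QK^T / sqrt(d) + mask) V` computation that the SDPA path replaces:
```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)
# Additive float mask: 0.0 = attend, large negative = masked (all zeros here)
attn_mask = torch.zeros(bsz, 1, seq_len, seq_len)

# Eager reference implementation
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim) + attn_mask
eager_out = scores.softmax(dim=-1) @ v

# SDPA equivalent
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True
```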
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
@ -364,10 +467,18 @@ class BertSelfOutput(nn.Module):
return hidden_states
BERT_SELF_ATTENTION_CLASSES = {
"eager": BertSelfAttention,
"sdpa": BertSdpaSelfAttention,
}
class BertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
@ -715,6 +826,7 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -859,6 +971,9 @@ class BertModel(BertPreTrainedModel):
self.pooler = BertPooler(config) if add_pooling_layer else None
self.attn_implementation = config._attn_implementation
self.position_embedding_type = config.position_embedding_type
# Initialize weights and apply final processing
self.post_init()
@ -945,9 +1060,6 @@ class BertModel(BertPreTrainedModel):
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
@ -956,9 +1068,43 @@ class BertModel(BertPreTrainedModel):
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
use_sdpa_attention_masks = (
self.attn_implementation == "sdpa"
and self.position_embedding_type == "absolute"
and head_mask is None
and not output_attentions
)
# Expand the attention mask
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
input_shape,
embedding_output,
past_key_values_length,
)
else:
extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
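For intuition, a simplified sketch of what `_prepare_4d_attention_mask_for_sdpa` conceptually produces from a 2D padding mask (the real helper additionally returns `None` for all-ones masks so SDPA can skip masking entirely):
```python
import torch


def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] with 1 = attend, 0 = padding  ->  [bsz, 1, tgt_len, src_len] additive mask
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # 0.0 where attention is allowed, a very large negative value on padding positions
    return (1.0 - expanded) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(expand_padding_mask(mask, torch.float16, tgt_len=4).shape)  # torch.Size([2, 1, 4, 4])
```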
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@ -967,7 +1113,15 @@ class BertModel(BertPreTrainedModel):
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
if use_sdpa_attention_masks:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
@ -978,13 +1132,6 @@ class BertModel(BertPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,

View File

@ -192,11 +192,18 @@ class BertGenerationSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration
BERT_GENERATION_SELF_ATTENTION_CLASSES = {
"eager": BertGenerationSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION
class BertGenerationAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = BertGenerationSelfOutput(config)
self.pruned_heads = set()

View File

@ -562,11 +562,18 @@ class BridgeTowerSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower
BRIDGE_TOWER_SELF_ATTENTION_CLASSES = {
"eager": BridgeTowerSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
class BridgeTowerAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = BridgeTowerSelfOutput(config)
self.pruned_heads = set()

View File

@ -312,11 +312,18 @@ class CamembertSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert
CAMEMBERT_SELF_ATTENTION_CLASSES = {
"eager": CamembertSelfAttention,
}
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT
class CamembertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = CamembertSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = CamembertSelfOutput(config)
self.pruned_heads = set()
@ -745,7 +752,7 @@ class CamembertModel(CamembertPreTrainedModel):
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -778,7 +785,7 @@ class CamembertModel(CamembertPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -354,11 +354,18 @@ class ChineseCLIPTextSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ChineseCLIPTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
class ChineseCLIPTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ChineseCLIPTextSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = ChineseCLIPTextSelfOutput(config)
self.pruned_heads = set()

View File

@ -1376,11 +1376,18 @@ class ClapTextSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText
CLAP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ClapTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ClapTextSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = ClapTextSelfOutput(config)
self.pruned_heads = set()
@ -1763,7 +1770,6 @@ class ClapTextModel(ClapPreTrainedModel):
config_class = ClapTextConfig
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->ClapText
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -1782,7 +1788,6 @@ class ClapTextModel(ClapPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -298,11 +298,18 @@ class Data2VecTextSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText
DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
"eager": Data2VecTextSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
class Data2VecTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = Data2VecTextSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = Data2VecTextSelfOutput(config)
self.pruned_heads = set()
@ -727,7 +734,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -355,11 +355,18 @@ class ElectraSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
ELECTRA_SELF_ATTENTION_CLASSES = {
"eager": ElectraSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
class ElectraAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = ElectraSelfOutput(config)
self.pruned_heads = set()

View File

@ -285,11 +285,18 @@ class ErnieSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie
ERNIE_SELF_ATTENTION_CLASSES = {
"eager": ErnieSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE
class ErnieAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = ErnieSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = ErnieSelfOutput(config)
self.pruned_heads = set()
@ -787,7 +794,7 @@ class ErnieModel(ErniePreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Ernie
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

View File

@ -267,11 +267,18 @@ class GitSelfOutput(nn.Module):
return hidden_states
GIT_SELF_ATTENTION_CLASSES = {
"eager": GitSelfAttention,
}
class GitAttention(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = GitSelfOutput(config)
self.pruned_heads = set()

View File

@ -276,11 +276,18 @@ class LayoutLMSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM
LAYOUTLM_SELF_ATTENTION_CLASSES = {
"eager": LayoutLMSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
class LayoutLMAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = LayoutLMSelfOutput(config)
self.pruned_heads = set()

View File

@ -468,11 +468,18 @@ class MarkupLMSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM
MARKUPLM_SELF_ATTENTION_CLASSES = {
"eager": MarkupLMSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
class MarkupLMAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = MarkupLMSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = MarkupLMSelfOutput(config)
self.pruned_heads = set()
@ -797,7 +804,7 @@ MARKUPLM_INPUTS_DOCSTRING = r"""
MARKUPLM_START_DOCSTRING,
)
class MarkupLMModel(MarkupLMPreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->MarkupLM
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

View File

@ -368,11 +368,18 @@ class RealmSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm
REALM_SELF_ATTENTION_CLASSES = {
"eager": RealmSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm,BERT->REALM
class RealmAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = REALM_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = RealmSelfOutput(config)
self.pruned_heads = set()

View File

@ -294,11 +294,18 @@ class RobertaSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": RobertaSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA
class RobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = RobertaSelfOutput(config)
self.pruned_heads = set()
@ -688,7 +695,7 @@ class RobertaModel(RobertaPreTrainedModel):
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -721,7 +728,7 @@ class RobertaModel(RobertaPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -431,11 +431,18 @@ class RoCBertSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert
ROC_BERT_SELF_ATTENTION_CLASSES = {
"eager": RoCBertSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT
class RoCBertAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = RoCBertSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = RoCBertSelfOutput(config)
self.pruned_heads = set()
@ -759,7 +766,6 @@ class RoCBertOnlyMLMHead(nn.Module):
return prediction_scores
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert->RoCBert,bert->roc_bert
class RoCBertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -880,7 +886,7 @@ class RoCBertModel(RoCBertPreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->RoCBert
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->RoCBert
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

View File

@ -245,11 +245,18 @@ class SplinterSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
SPLINTER_SELF_ATTENTION_CLASSES = {
"eager": SplinterSelfAttention,
}
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
class SplinterAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = SplinterSelfOutput(config)
self.pruned_heads = set()

View File

@ -295,11 +295,18 @@ class XLMRobertaSelfOutput(nn.Module):
return hidden_states
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta
XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
"eager": XLMRobertaSelfAttention,
}
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA
class XLMRobertaAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.output = XLMRobertaSelfOutput(config)
self.pruned_heads = set()
@ -690,7 +697,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -723,7 +730,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -664,7 +664,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRobertaXL
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
@ -697,7 +697,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,

View File

@ -783,7 +783,7 @@ class XmodModel(XmodPreTrainedModel):
"""
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Xmod
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Xmod
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

View File

@ -18,7 +18,14 @@ import unittest
from transformers import BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device
from transformers.testing_utils import (
CaptureLogger,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -621,6 +628,79 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
# This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True.
# TODO: Remove this and use the parent method (in common tests) once BERT supports low_cpu_mem_usage=True.
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
max_new_tokens = 30
if len(self.all_generative_model_classes) == 0:
self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test")
for model_class in self.all_generative_model_classes:
if not model_class._supports_sdpa:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
model_sdpa = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
# low_cpu_mem_usage=True,
).to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
# low_cpu_mem_usage=True,
attn_implementation="eager",
).to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")
# Just test that a large cache works as expected
res_eager = model_eager.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
res_sdpa = model_sdpa.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
)
self.assertTrue(torch.allclose(res_eager, res_sdpa))
@require_torch
class BertModelIntegrationTest(unittest.TestCase):

View File

@ -3603,12 +3603,14 @@ class ModelTesterMixin:
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
@ -3691,19 +3693,21 @@ class ModelTesterMixin:
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
other_inputs = {
processed_inputs = {
model.main_input_name: dummy_input,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
other_inputs = {
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
processed_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
@ -3712,8 +3716,9 @@ class ModelTesterMixin:
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
@ -3799,6 +3804,7 @@ class ModelTesterMixin:
self.skipTest(f"{model_class.__name__} does not support SDPA")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
if config.model_type in ["llava", "llava_next", "vipllava"]:
self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
if config.model_type in ["idefics"]:
@ -3867,12 +3873,14 @@ class ModelTesterMixin:
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa: