[`BERT`] Add support for sdpa (#28802)
* Adding SDPA support for BERT
* Using the proper input name for testing model input in inference()
* Adding documentation for SDPA in BERT model page
* Use the stable link for the documentation
* Adding a gate to only call .contiguous() for torch < 2.2.0
* Additions and fixes to the documentation
* Minor updates to documentation
* Adding extra requirements needed for the contiguous() bug
* Adding "Adapted from" in place of the "Copied from"
* Add benchmark speedup tables to the documentation
* Minor fixes to the documentation
* Use ClapText as a replacement for Bert in the Copied-From
* Some more fixes for the fix-copies references
* Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage [test all]
* Undo changes to separate test
* Refactored SDPA self-attention code for KV projections
* Change use_sdpa to attn_implementation
* Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)
This commit is contained in: parent 2de5cb12be, commit dfa7b580e9
@ -61,6 +61,53 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o

- The model must predict the original sentence, but has a second objective: inputs are two sentences A and B (with a separation token in between). With probability 50% the sentences are consecutive in the corpus, and in the remaining 50% they are not related. The model has to predict whether the sentences are consecutive or not.

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```python
import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa")
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the
following speedups during training and inference.

#### Training

|batch_size|seq_len|Time per batch (eager - s)|Time per batch (sdpa - s)|Speedup (%)|Eager peak mem (MB)|sdpa peak mem (MB)|Mem saving (%)|
|----------|-------|--------------------------|-------------------------|-----------|-------------------|------------------|--------------|
|4 |256 |0.023 |0.017 |35.472 |939.213 |764.834 |22.800 |
|4 |512 |0.023 |0.018 |23.687 |1970.447 |1227.162 |60.569 |
|8 |256 |0.023 |0.018 |23.491 |1594.295 |1226.114 |30.028 |
|8 |512 |0.035 |0.025 |43.058 |3629.401 |2134.262 |70.054 |
|16 |256 |0.030 |0.024 |25.583 |2874.426 |2134.262 |34.680 |
|16 |512 |0.064 |0.044 |46.223 |6964.659 |3961.013 |75.830 |

#### Inference

|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%)|Mem eager (MB)|Mem SDPA (MB)|Mem saved (%)|
|----------|-------|----------------------------|---------------------------|-----------|--------------|-------------|-------------|
|1 |128 |5.736 |4.987 |15.022 |282.661 |282.924 |-0.093 |
|1 |256 |5.689 |4.945 |15.055 |298.686 |298.948 |-0.088 |
|2 |128 |6.154 |4.982 |23.521 |314.523 |314.785 |-0.083 |
|2 |256 |6.201 |4.949 |25.303 |347.546 |347.033 |0.148 |
|4 |128 |6.049 |4.987 |21.305 |378.895 |379.301 |-0.107 |
|4 |256 |6.285 |5.364 |17.166 |443.209 |444.382 |-0.264 |
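The tables above report the maintainers' local measurements. As a rough, hedged sketch (not the original benchmark script; the checkpoint and batch shape are assumptions matching the first training row), a per-batch training step could be timed like this:

```python
import time

import torch
from transformers import BertForMaskedLM

device = "cuda"
model = BertForMaskedLM.from_pretrained(
    "bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa"
).to(device)

# batch_size=4, seq_len=256, as in the first row of the training table
input_ids = torch.randint(0, model.config.vocab_size, (4, 256), device=device)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()
loss = model(input_ids=input_ids, labels=input_ids).loss  # forward pass
loss.backward()                                           # backward pass
torch.cuda.synchronize()

print(f"time per batch: {time.perf_counter() - start:.3f} s")
print(f"peak mem: {torch.cuda.max_memory_allocated() / 2**20:.1f} MB")
```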
## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
@ -187,10 +187,11 @@ FlashAttention is more memory efficient, meaning you can train on much larger se

## PyTorch scaled dot product attention

-PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available.
+PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used, as in the sketch after the list below.

For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
+* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
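A minimal illustration of the opt-in flag described above (the checkpoint name is just an example; any architecture from the list works):

```python
import torch
from transformers import AutoModel

# explicitly request SDPA instead of relying on the default dispatch
model = AutoModel.from_pretrained(
    "bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa"
)
```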
@ -224,6 +225,13 @@ FlashAttention can only be used for models with the `fp16` or `bf16` torch type,

</Tip>

<Tip>

SDPA does not support certain sets of attention parameters, such as `head_mask` and `output_attentions=True`.
In that case, you will see a warning message and the model falls back to the (slower) eager implementation.

</Tip>
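For example, requesting attention weights triggers the fallback described in the tip above (a hedged sketch; the checkpoint and input are placeholders):

```python
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")

inputs = tokenizer("Hello world", return_tensors="pt")
# SDPA cannot return attention weights, so this call logs a warning and
# dispatches to the eager attention implementation instead.
outputs = model(**inputs, output_attentions=True)
```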
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

@ -232,8 +240,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer

```diff
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
-# convert the model to BetterTransformer
-model.to_bettertransformer()

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
```
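Continuing that example, the context manager mentioned above can pin SDPA to a single backend (a sketch; note that `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel` in newer PyTorch releases):

```python
import torch

# force the FlashAttention backend only; SDPA errors out if it is unavailable
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)
```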
@ -445,10 +445,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
         or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
     )

-    if torch.all(mask == 1):
-        if is_tracing:
-            pass
-        elif tgt_len == 1:
+    if not is_tracing and torch.all(mask == 1):
+        if tgt_len == 1:
             # For query_length == 1, causal attention and bi-directional attention are the same.
             return None
         elif key_value_length == tgt_len:
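A quick hedged illustration of what the simplified branch above does: for a fully unmasked input whose key/value length equals the query length, the helper returns `None` so SDPA receives no mask at all (toy shapes chosen here):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

mask = torch.ones(2, 7)  # [batch_size, key_value_length], nothing masked
out = _prepare_4d_attention_mask_for_sdpa(mask, torch.float16, tgt_len=7)
print(out)  # None: the all-ones mask is dropped entirely
```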
@ -883,11 +883,18 @@ class AlignTextSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText
+ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
+    "eager": AlignTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
 class AlignTextAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = AlignTextSelfOutput(config)
         self.pruned_heads = set()
@ -434,11 +434,18 @@ class AltRobertaSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta
+ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
+    "eager": AltRobertaSelfAttention,
+}
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
 class AltRobertaAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = AltRobertaSelfOutput(config)
         self.pruned_heads = set()
@ -1205,7 +1212,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):

     config_class = AltCLIPTextConfig

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->AltRoberta
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -1232,7 +1239,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -23,10 +23,15 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,

@ -45,6 +50,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    get_torch_version,
     logging,
     replace_return_docstrings,
 )
@ -350,6 +356,103 @@ class BertSelfAttention(nn.Module):
         return outputs


+class BertSdpaSelfAttention(BertSelfAttention):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__(config, position_embedding_type=position_embedding_type)
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+    # Adapted from BertSelfAttention
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
+            logger.warning_once(
+                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
+                "the manual attention implementation, but specifying the manual implementation will be required from "
+                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
+                '`attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal
+        # mask in case tgt_len == 1.
+        is_causal = self.is_decoder and attention_mask is None and tgt_len > 1
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
 class BertSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
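The `require_contiguous_qkv` gate above can be checked in isolation; a minimal sketch using the same helpers this diff imports:

```python
from packaging import version
from transformers.utils import get_torch_version

# True on torch < 2.2.0, where non-contiguous q/k/v plus a custom attn_mask
# break SDPA's memory-efficient backend (pytorch/pytorch#112577)
needs_contiguous = version.parse(get_torch_version()) < version.parse("2.2.0")
print(needs_contiguous)
```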
@ -364,10 +467,18 @@ class BertSelfOutput(nn.Module):
         return hidden_states


+BERT_SELF_ATTENTION_CLASSES = {
+    "eager": BertSelfAttention,
+    "sdpa": BertSdpaSelfAttention,
+}
+
+
 class BertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = BertSelfOutput(config)
         self.pruned_heads = set()
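With this dispatch in place, you can verify which self-attention class was instantiated (a hedged sketch; the class names are the ones introduced in this diff):

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")
print(type(model.encoder.layer[0].attention.self).__name__)
# BertSdpaSelfAttention
```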
@ -715,6 +826,7 @@ class BertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
+    _supports_sdpa = True

     def _init_weights(self, module):
         """Initialize the weights"""
@ -859,6 +971,9 @@ class BertModel(BertPreTrainedModel):

         self.pooler = BertPooler(config) if add_pooling_layer else None

+        self.attn_implementation = config._attn_implementation
+        self.position_embedding_type = config.position_embedding_type
+
         # Initialize weights and apply final processing
         self.post_init()
@ -945,9 +1060,6 @@ class BertModel(BertPreTrainedModel):
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

-        if attention_mask is None:
-            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
-
         if token_type_ids is None:
             if hasattr(self.embeddings, "token_type_ids"):
                 buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
@ -956,9 +1068,43 @@ class BertModel(BertPreTrainedModel):
             else:
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+        use_sdpa_attention_masks = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type == "absolute"
+            and head_mask is None
+            and not output_attentions
+        )
+
+        # Expand the attention mask
+        if use_sdpa_attention_masks:
+            # Expand the attention mask for SDPA.
+            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+            if self.config.is_decoder:
+                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    input_shape,
+                    embedding_output,
+                    past_key_values_length,
+                )
+            else:
+                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+        else:
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@ -967,7 +1113,15 @@ class BertModel(BertPreTrainedModel):
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
+            if use_sdpa_attention_masks:
+                # Expand the attention mask for SDPA.
+                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@ -978,13 +1132,6 @@ class BertModel(BertPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,
@ -192,11 +192,18 @@ class BertGenerationSelfAttention(nn.Module):
         return outputs


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration
+BERT_GENERATION_SELF_ATTENTION_CLASSES = {
+    "eager": BertGenerationSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION
 class BertGenerationAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = BertGenerationSelfOutput(config)
         self.pruned_heads = set()
@ -562,11 +562,18 @@ class BridgeTowerSelfAttention(nn.Module):
         return outputs


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower
+BRIDGE_TOWER_SELF_ATTENTION_CLASSES = {
+    "eager": BridgeTowerSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
 class BridgeTowerAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = BridgeTowerSelfOutput(config)
         self.pruned_heads = set()
@ -312,11 +312,18 @@ class CamembertSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert
+CAMEMBERT_SELF_ATTENTION_CLASSES = {
+    "eager": CamembertSelfAttention,
+}
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT
 class CamembertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = CamembertSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = CamembertSelfOutput(config)
         self.pruned_heads = set()
@ -745,7 +752,7 @@ class CamembertModel(CamembertPreTrainedModel):

     _no_split_modules = []

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -778,7 +785,7 @@ class CamembertModel(CamembertPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -354,11 +354,18 @@ class ChineseCLIPTextSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText
+CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
+    "eager": ChineseCLIPTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
 class ChineseCLIPTextAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ChineseCLIPTextSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = ChineseCLIPTextSelfOutput(config)
         self.pruned_heads = set()
@ -1376,11 +1376,18 @@ class ClapTextSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText
+CLAP_TEXT_SELF_ATTENTION_CLASSES = {
+    "eager": ClapTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
 class ClapTextAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ClapTextSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = ClapTextSelfOutput(config)
         self.pruned_heads = set()
@ -1763,7 +1770,6 @@ class ClapTextModel(ClapPreTrainedModel):

     config_class = ClapTextConfig

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->ClapText
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -1782,7 +1788,6 @@ class ClapTextModel(ClapPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value

-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -298,11 +298,18 @@ class Data2VecTextSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText
+DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = {
+    "eager": Data2VecTextSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT
 class Data2VecTextAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = Data2VecTextSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = Data2VecTextSelfOutput(config)
         self.pruned_heads = set()
@ -727,7 +734,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -355,11 +355,18 @@ class ElectraSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
+ELECTRA_SELF_ATTENTION_CLASSES = {
+    "eager": ElectraSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
 class ElectraAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = ElectraSelfOutput(config)
         self.pruned_heads = set()
@ -285,11 +285,18 @@ class ErnieSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie
+ERNIE_SELF_ATTENTION_CLASSES = {
+    "eager": ErnieSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE
 class ErnieAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ErnieSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = ErnieSelfOutput(config)
         self.pruned_heads = set()

@ -787,7 +794,7 @@ class ErnieModel(ErniePreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Ernie
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@ -267,11 +267,18 @@ class GitSelfOutput(nn.Module):
         return hidden_states


+GIT_SELF_ATTENTION_CLASSES = {
+    "eager": GitSelfAttention,
+}
+
+
 class GitAttention(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = GitSelfOutput(config)
         self.pruned_heads = set()
@ -276,11 +276,18 @@ class LayoutLMSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM
+LAYOUTLM_SELF_ATTENTION_CLASSES = {
+    "eager": LayoutLMSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
 class LayoutLMAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = LayoutLMSelfOutput(config)
         self.pruned_heads = set()
@ -468,11 +468,18 @@ class MarkupLMSelfAttention(nn.Module):
         return outputs


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM
+MARKUPLM_SELF_ATTENTION_CLASSES = {
+    "eager": MarkupLMSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
 class MarkupLMAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = MarkupLMSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = MarkupLMSelfOutput(config)
         self.pruned_heads = set()

@ -797,7 +804,7 @@ MARKUPLM_INPUTS_DOCSTRING = r"""
     MARKUPLM_START_DOCSTRING,
 )
 class MarkupLMModel(MarkupLMPreTrainedModel):
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->MarkupLM
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@ -368,11 +368,18 @@ class RealmSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm
+REALM_SELF_ATTENTION_CLASSES = {
+    "eager": RealmSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm,BERT->REALM
 class RealmAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = REALM_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = RealmSelfOutput(config)
         self.pruned_heads = set()
@ -294,11 +294,18 @@ class RobertaSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
+ROBERTA_SELF_ATTENTION_CLASSES = {
+    "eager": RobertaSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA
 class RobertaAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = RobertaSelfOutput(config)
         self.pruned_heads = set()

@ -688,7 +695,7 @@ class RobertaModel(RobertaPreTrainedModel):

     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -721,7 +728,7 @@ class RobertaModel(RobertaPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -431,11 +431,18 @@ class RoCBertSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert
+ROC_BERT_SELF_ATTENTION_CLASSES = {
+    "eager": RoCBertSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT
 class RoCBertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = RoCBertSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = RoCBertSelfOutput(config)
         self.pruned_heads = set()

@ -759,7 +766,6 @@ class RoCBertOnlyMLMHead(nn.Module):
         return prediction_scores


-# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert->RoCBert,bert->roc_bert
 class RoCBertPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained

@ -880,7 +886,7 @@ class RoCBertModel(RoCBertPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->RoCBert
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->RoCBert
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@ -245,11 +245,18 @@ class SplinterSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
+SPLINTER_SELF_ATTENTION_CLASSES = {
+    "eager": SplinterSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
 class SplinterAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = SplinterSelfOutput(config)
         self.pruned_heads = set()
@ -295,11 +295,18 @@ class XLMRobertaSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta
+XLM_ROBERTA_SELF_ATTENTION_CLASSES = {
+    "eager": XLMRobertaSelfAttention,
+}
+
+
+# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA
 class XLMRobertaAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+        self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
         self.output = XLMRobertaSelfOutput(config)
         self.pruned_heads = set()

@ -690,7 +697,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):

     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -723,7 +730,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -664,7 +664,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
     an input to the forward pass. .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRobertaXL
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@ -697,7 +697,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel):
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@ -783,7 +783,7 @@ class XmodModel(XmodPreTrainedModel):

     """

-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Xmod
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Xmod
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
@ -18,7 +18,14 @@ import unittest

 from transformers import BertConfig, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device
+from transformers.testing_utils import (
+    CaptureLogger,
+    require_torch,
+    require_torch_accelerator,
+    require_torch_sdpa,
+    slow,
+    torch_device,
+)

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@ -621,6 +628,79 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
             loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
             loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))

+    # This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True.
+    # TODO: Remove this and use the parent method (in common tests) once BERT supports low_cpu_mem_usage=True.
+    @require_torch_sdpa
+    @slow
+    def test_eager_matches_sdpa_generate(self):
+        max_new_tokens = 30
+
+        if len(self.all_generative_model_classes) == 0:
+            self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")
+
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_sdpa:
+                self.skipTest(f"{model_class.__name__} does not support SDPA")
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            dummy_input = inputs_dict[model_class.main_input_name]
+            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+                dummy_input = dummy_input.to(torch.float16)
+
+            # make sure that all models have enough positions for generation
+            if hasattr(config, "max_position_embeddings"):
+                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+                model_sdpa = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    # low_cpu_mem_usage=True,
+                ).to(torch_device)
+
+                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    torch_dtype=torch.float16,
+                    # low_cpu_mem_usage=True,
+                    attn_implementation="eager",
+                ).to(torch_device)
+
+                self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+                for name, submodule in model_eager.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        raise ValueError("The eager model should not have SDPA attention layers")
+
+                has_sdpa = False
+                for name, submodule in model_sdpa.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        has_sdpa = True
+                        break
+                if not has_sdpa:
+                    raise ValueError("The SDPA model should have SDPA attention layers")
+
+                # Just test that a large cache works as expected
+                res_eager = model_eager.generate(
+                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                )
+
+                res_sdpa = model_sdpa.generate(
+                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
+                )
+
+                self.assertTrue(torch.allclose(res_eager, res_sdpa))
+

 @require_torch
 class BertModelIntegrationTest(unittest.TestCase):
@ -3603,12 +3603,14 @@ class ModelTesterMixin:
                 self.assertTrue(model_eager.config._attn_implementation == "eager")

                 for name, submodule in model_eager.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         raise ValueError("The eager model should not have SDPA attention layers")

                 has_sdpa = False
                 for name, submodule in model_sdpa.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         has_sdpa = True
                         break
                 if not has_sdpa and model_sdpa.config.model_type != "falcon":
@ -3691,19 +3693,21 @@ class ModelTesterMixin:
                     decoder_input_ids = decoder_input_ids.to(torch_device)

                     # TODO: never an `attention_mask` arg here?
-                    other_inputs = {
+                    processed_inputs = {
                         model.main_input_name: dummy_input,
                         "decoder_input_ids": decoder_input_ids,
                         "decoder_attention_mask": dummy_attention_mask,
                         "output_hidden_states": True,
                     }
                 else:
-                    other_inputs = {
+                    processed_inputs = {
                         model.main_input_name: dummy_input,
                         "output_hidden_states": True,
                     }

                     # Otherwise fails for e.g. WhisperEncoderModel
                     if "attention_mask" in inspect.signature(model_eager.forward).parameters:
-                        other_inputs["attention_mask"] = dummy_attention_mask
+                        processed_inputs["attention_mask"] = dummy_attention_mask

                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
@ -3712,8 +3716,9 @@ class ModelTesterMixin:
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
                     ):
-                        outputs_eager = model_eager(dummy_input, **other_inputs)
-                        outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
+                        prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                        outputs_eager = model_eager(**prepared_inputs)
+                        outputs_sdpa = model_sdpa(**prepared_inputs)

                 logits_eager = (
                     outputs_eager.hidden_states[-1]
@ -3799,6 +3804,7 @@ class ModelTesterMixin:
                 self.skipTest(f"{model_class.__name__} does not support SDPA")

             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
             if config.model_type in ["llava", "llava_next", "vipllava"]:
                 self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input")
             if config.model_type in ["idefics"]:
@ -3867,12 +3873,14 @@ class ModelTesterMixin:
                 self.assertTrue(model_eager.config._attn_implementation == "eager")

                 for name, submodule in model_eager.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         raise ValueError("The eager model should not have SDPA attention layers")

                 has_sdpa = False
                 for name, submodule in model_sdpa.named_modules():
-                    if "SdpaAttention" in submodule.__class__.__name__:
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                         has_sdpa = True
                         break
                 if not has_sdpa: