From dfa7b580e9863c38c2f0e0dedf0958c2eab9cb48 Mon Sep 17 00:00:00 2001 From: "JB (Don)" <1557853+hackyon@users.noreply.github.com> Date: Fri, 26 Apr 2024 23:23:44 +0800 Subject: [PATCH] [`BERT`] Add support for sdpa (#28802) * Adding SDPA support for BERT * Using the proper input name for testing model input in inference() * Adding documentation for SDPA in BERT model page * Use the stable link for the documentation * Adding a gate to only call .contiguous() for torch < 2.2.0 * Additions and fixes to the documentation * Minor updates to documentation * Adding extra requirements needed for the contiguous() bug * Adding "Adapted from" in plcae of the "Copied from" * Add benchmark speedup tables to the documentation * Minor fixes to the documentation * Use ClapText as a replacemenet for Bert in the Copied-From * Some more fixes for the fix-copies references * Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage [test all] * Undo changes to separate test * Refactored SDPA self attention code for KV projections * Change use_sdpa to attn_implementation * Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models) --- docs/source/en/model_doc/bert.md | 47 +++++ docs/source/en/perf_infer_gpu_one.md | 12 +- src/transformers/modeling_attn_mask_utils.py | 6 +- .../models/align/modeling_align.py | 11 +- .../models/altclip/modeling_altclip.py | 15 +- src/transformers/models/bert/modeling_bert.py | 177 ++++++++++++++++-- .../modeling_bert_generation.py | 11 +- .../bridgetower/modeling_bridgetower.py | 11 +- .../models/camembert/modeling_camembert.py | 15 +- .../chinese_clip/modeling_chinese_clip.py | 11 +- src/transformers/models/clap/modeling_clap.py | 13 +- .../models/data2vec/modeling_data2vec_text.py | 13 +- .../models/electra/modeling_electra.py | 11 +- .../models/ernie/modeling_ernie.py | 13 +- src/transformers/models/git/modeling_git.py | 11 +- .../models/layoutlm/modeling_layoutlm.py | 11 +- .../models/markuplm/modeling_markuplm.py | 13 +- .../models/realm/modeling_realm.py | 11 +- .../models/roberta/modeling_roberta.py | 15 +- .../models/roc_bert/modeling_roc_bert.py | 14 +- .../models/splinter/modeling_splinter.py | 11 +- .../xlm_roberta/modeling_xlm_roberta.py | 15 +- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 4 +- src/transformers/models/xmod/modeling_xmod.py | 2 +- tests/models/bert/test_modeling_bert.py | 82 +++++++- tests/test_modeling_common.py | 26 ++- 26 files changed, 495 insertions(+), 86 deletions(-) diff --git a/docs/source/en/model_doc/bert.md b/docs/source/en/model_doc/bert.md index c77a1d8525..b6e99d1031 100644 --- a/docs/source/en/model_doc/bert.md +++ b/docs/source/en/model_doc/bert.md @@ -61,6 +61,53 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o - The model must predict the original sentence, but has a second objective: inputs are two sentences A and B (with a separation token in between). With probability 50%, the sentences are consecutive in the corpus, in the remaining 50% they are not related. The model has to predict if the sentences are consecutive or not. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. 
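As an illustrative aside (a sketch, not part of the patched documentation), the operator that the SDPA code path ultimately calls takes `(batch, num_heads, seq_len, head_dim)` tensors and returns the attended values, with PyTorch selecting the backend:

```
import torch
import torch.nn.functional as F

# Toy tensors shaped (batch, num_heads, seq_len, head_dim); values are random.
query = torch.randn(1, 12, 128, 64)
key = torch.randn(1, 12, 128, 64)
value = torch.randn(1, 12, 128, 64)

# PyTorch dispatches to the most efficient backend available for these inputs
# (FlashAttention, memory-efficient attention, or the math fallback).
out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 12, 128, 64])
```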
See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +``` +from transformers import BertModel + +model = BertModel.from_pretrained("bert-base-uncased", torch_dtype=torch.float16, attn_implementation="sdpa") +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (A100-80GB, CPUx12, RAM 96.6GB, PyTorch 2.2.0, OS Ubuntu 22.04) with `float16`, we saw the +following speedups during training and inference. + +#### Training + +|batch_size|seq_len|Time per batch (eager - s)|Time per batch (sdpa - s)|Speedup (%)|Eager peak mem (MB)|sdpa peak mem (MB)|Mem saving (%)| +|----------|-------|--------------------------|-------------------------|-----------|-------------------|------------------|--------------| +|4 |256 |0.023 |0.017 |35.472 |939.213 |764.834 |22.800 | +|4 |512 |0.023 |0.018 |23.687 |1970.447 |1227.162 |60.569 | +|8 |256 |0.023 |0.018 |23.491 |1594.295 |1226.114 |30.028 | +|8 |512 |0.035 |0.025 |43.058 |3629.401 |2134.262 |70.054 | +|16 |256 |0.030 |0.024 |25.583 |2874.426 |2134.262 |34.680 | +|16 |512 |0.064 |0.044 |46.223 |6964.659 |3961.013 |75.830 | + +#### Inference + +|batch_size|seq_len|Per token latency eager (ms)|Per token latency SDPA (ms)|Speedup (%)|Mem eager (MB)|Mem BT (MB)|Mem saved (%)| +|----------|-------|----------------------------|---------------------------|-----------|--------------|-----------|-------------| +|1 |128 |5.736 |4.987 |15.022 |282.661 |282.924 |-0.093 | +|1 |256 |5.689 |4.945 |15.055 |298.686 |298.948 |-0.088 | +|2 |128 |6.154 |4.982 |23.521 |314.523 |314.785 |-0.083 | +|2 |256 |6.201 |4.949 |25.303 |347.546 |347.033 |0.148 | +|4 |128 |6.049 |4.987 |21.305 |378.895 |379.301 |-0.107 | +|4 |256 |6.285 |5.364 |17.166 |443.209 |444.382 |-0.264 | + + + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BERT. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 494ba660fa..64583e4bad 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -187,10 +187,11 @@ FlashAttention is more memory efficient, meaning you can train on much larger se ## PyTorch scaled dot product attention -PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. 
+PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. For now, Transformers supports SDPA inference and training for the following architectures: * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) @@ -224,6 +225,13 @@ FlashAttention can only be used for models with the `fp16` or `bf16` torch type, + + +SDPA does not support certain sets of attention parameters, such as `head_mask` and `output_attentions=True`. +In that case, you should see a warning message and we will fall back to the (slower) eager implementation. + + + By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: ```diff @@ -232,8 +240,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda") -# convert the model to BetterTransformer -model.to_bettertransformer() input_text = "Hello my dog is cute and" inputs = tokenizer(input_text, return_tensors="pt").to("cuda") diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index c69d9555b2..44ea179566 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -445,10 +445,8 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) - if torch.all(mask == 1): - if is_tracing: - pass - elif tgt_len == 1: + if not is_tracing and torch.all(mask == 1): + if tgt_len == 1: # For query_length == 1, causal attention and bi-directional attention are the same. 
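            # Returning None means no mask tensor is materialized: downstream SDPA calls receive
            # attn_mask=None and stay eligible for the fused FlashAttention / memory-efficient kernels.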
return None elif key_value_length == tgt_len: diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 3dce9d383d..0f8246e8f9 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -883,11 +883,18 @@ class AlignTextSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText +ALIGN_TEXT_SELF_ATTENTION_CLASSES = { + "eager": AlignTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT class AlignTextAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = AlignTextSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = AlignTextSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 0d27d87de7..ba8abb311a 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -434,11 +434,18 @@ class AltRobertaSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta +ALT_ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": AltRobertaSelfAttention, +} + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA class AltRobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = AltRobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = AltRobertaSelfOutput(config) self.pruned_heads = set() @@ -1205,7 +1212,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel): config_class = AltCLIPTextConfig - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->AltRoberta + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->AltRoberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -1232,7 +1239,7 @@ class AltRobertaModel(AltCLIPPreTrainedModel): for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 262fc79f0d..f7af0f1ef5 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -23,10 +23,15 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint +from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + 
_prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -45,6 +50,7 @@ from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + get_torch_version, logging, replace_return_docstrings, ) @@ -350,6 +356,103 @@ class BertSelfAttention(nn.Module): return outputs +class BertSdpaSelfAttention(BertSelfAttention): + def __init__(self, config, position_embedding_type=None): + super().__init__(config, position_embedding_type=position_embedding_type) + self.dropout_prob = config.attention_probs_dropout_prob + self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") + + # Adapted from BertSelfAttention + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: + # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. + logger.warning_once( + "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " + "the manual attention implementation, but specifying the manual implementation will be required from " + "Transformers version v5.0.0 onwards. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal + # mask in case tgt_len == 1. + is_causal = self.is_decoder and attention_mask is None and tgt_len > 1 + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -364,10 +467,18 @@ class BertSelfOutput(nn.Module): return hidden_states +BERT_SELF_ATTENTION_CLASSES = { + "eager": BertSelfAttention, + "sdpa": BertSdpaSelfAttention, +} + + class BertAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = BertSelfOutput(config) self.pruned_heads = set() @@ -715,6 +826,7 @@ class BertPreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" supports_gradient_checkpointing = True + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" @@ -859,6 +971,9 @@ class BertModel(BertPreTrainedModel): self.pooler = BertPooler(config) if add_pooling_layer else None + self.attn_implementation = config._attn_implementation + self.position_embedding_type = config.position_embedding_type + # Initialize weights and apply final processing self.post_init() @@ -945,9 +1060,6 @@ class BertModel(BertPreTrainedModel): # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) - if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] @@ -956,9 +1068,43 @@ class BertModel(BertPreTrainedModel): else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - # We can provide a self-attention mask of 
dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -967,7 +1113,15 @@ class BertModel(BertPreTrainedModel): encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + + if use_sdpa_attention_masks: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -978,13 +1132,6 @@ class BertModel(BertPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index b7250f6f7b..73c4d1d1e5 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -192,11 +192,18 @@ class BertGenerationSelfAttention(nn.Module): return outputs -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration +BERT_GENERATION_SELF_ATTENTION_CLASSES = { + "eager": BertGenerationSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BertGeneration,BERT->BERT_GENERATION class BertGenerationAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = BertGenerationSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = BERT_GENERATION_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = BertGenerationSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index bcace39b29..3fc9f755aa 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -562,11 +562,18 @@ class BridgeTowerSelfAttention(nn.Module): return outputs -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower +BRIDGE_TOWER_SELF_ATTENTION_CLASSES = { + "eager": BridgeTowerSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER class BridgeTowerAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = BridgeTowerSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = BRIDGE_TOWER_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = BridgeTowerSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 26250896b2..f399fb3f5c 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -312,11 +312,18 @@ class CamembertSelfOutput(nn.Module): return hidden_states -# Copied from 
transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert +CAMEMBERT_SELF_ATTENTION_CLASSES = { + "eager": CamembertSelfAttention, +} + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->Camembert,ROBERTA->CAMEMBERT class CamembertAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = CamembertSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = CAMEMBERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = CamembertSelfOutput(config) self.pruned_heads = set() @@ -745,7 +752,7 @@ class CamembertModel(CamembertPreTrainedModel): _no_split_modules = [] - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Camembert def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -778,7 +785,7 @@ class CamembertModel(CamembertPreTrainedModel): output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index d8e97c20b2..87a1baa217 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -354,11 +354,18 @@ class ChineseCLIPTextSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText +CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = { + "eager": ChineseCLIPTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT class ChineseCLIPTextAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ChineseCLIPTextSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = ChineseCLIPTextSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 7b20b30137..c21e173133 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1376,11 +1376,18 @@ class ClapTextSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText +CLAP_TEXT_SELF_ATTENTION_CLASSES = { + "eager": ClapTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT class ClapTextAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ClapTextSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) 
self.output = ClapTextSelfOutput(config) self.pruned_heads = set() @@ -1763,7 +1770,6 @@ class ClapTextModel(ClapPreTrainedModel): config_class = ClapTextConfig - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->ClapText def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -1782,7 +1788,6 @@ class ClapTextModel(ClapPreTrainedModel): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - # Copied from transformers.models.bert.modeling_bert.BertModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7dcc53e2cc..20e1e1eca5 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -298,11 +298,18 @@ class Data2VecTextSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText +DATA2VEC_TEXT_SELF_ATTENTION_CLASSES = { + "eager": Data2VecTextSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Data2VecText,BERT->DATA2VEC_TEXT class Data2VecTextAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = Data2VecTextSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = DATA2VEC_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = Data2VecTextSelfOutput(config) self.pruned_heads = set() @@ -727,7 +734,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel): output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 2138aa97c6..6fbdda2579 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -355,11 +355,18 @@ class ElectraSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra +ELECTRA_SELF_ATTENTION_CLASSES = { + "eager": ElectraSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA class ElectraAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = ElectraSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index a65f453205..3db6501985 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -285,11 +285,18 @@ class ErnieSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie 
+ERNIE_SELF_ATTENTION_CLASSES = { + "eager": ErnieSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE class ErnieAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = ErnieSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = ErnieSelfOutput(config) self.pruned_heads = set() @@ -787,7 +794,7 @@ class ErnieModel(ErniePreTrainedModel): `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Ernie + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c8953d4984..12821609f0 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -267,11 +267,18 @@ class GitSelfOutput(nn.Module): return hidden_states +GIT_SELF_ATTENTION_CLASSES = { + "eager": GitSelfAttention, +} + + class GitAttention(nn.Module): - # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git + # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git,BERT->GIT def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = GitSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index c570fdb124..6914f5ee3e 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -276,11 +276,18 @@ class LayoutLMSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM +LAYOUTLM_SELF_ATTENTION_CLASSES = { + "eager": LayoutLMSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM class LayoutLMAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = LayoutLMSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 2058ce2795..318110daf5 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -468,11 +468,18 @@ class MarkupLMSelfAttention(nn.Module): return outputs -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM +MARKUPLM_SELF_ATTENTION_CLASSES = { + "eager": 
MarkupLMSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM class MarkupLMAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = MarkupLMSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = MarkupLMSelfOutput(config) self.pruned_heads = set() @@ -797,7 +804,7 @@ MARKUPLM_INPUTS_DOCSTRING = r""" MARKUPLM_START_DOCSTRING, ) class MarkupLMModel(MarkupLMPreTrainedModel): - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->MarkupLM + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 86f2894289..adec5647a2 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -368,11 +368,18 @@ class RealmSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm +REALM_SELF_ATTENTION_CLASSES = { + "eager": RealmSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Realm,BERT->REALM class RealmAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = RealmSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = REALM_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = RealmSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index e1f15722e4..6401392120 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -294,11 +294,18 @@ class RobertaSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta +ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": RobertaSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA class RobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -688,7 +695,7 @@ class RobertaModel(RobertaPreTrainedModel): """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Roberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -721,7 +728,7 @@ class RobertaModel(RobertaPreTrainedModel): output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from 
transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 51850c9af1..739e60b550 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -431,11 +431,18 @@ class RoCBertSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert +ROC_BERT_SELF_ATTENTION_CLASSES = { + "eager": RoCBertSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RoCBert,BERT->ROC_BERT class RoCBertAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = RoCBertSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = ROC_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = RoCBertSelfOutput(config) self.pruned_heads = set() @@ -759,7 +766,6 @@ class RoCBertOnlyMLMHead(nn.Module): return prediction_scores -# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert->RoCBert,bert->roc_bert class RoCBertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -880,7 +886,7 @@ class RoCBertModel(RoCBertPreTrainedModel): `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->RoCBert + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->RoCBert def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index b643601d0e..fa546e1201 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -245,11 +245,18 @@ class SplinterSelfOutput(nn.Module): return hidden_states -# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter +SPLINTER_SELF_ATTENTION_CLASSES = { + "eager": SplinterSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER class SplinterAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = SplinterSelfOutput(config) self.pruned_heads = set() diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 0d829aaee6..48c6898811 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -295,11 +295,18 @@ class XLMRobertaSelfOutput(nn.Module): return hidden_states -# Copied from 
transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta +XLM_ROBERTA_SELF_ATTENTION_CLASSES = { + "eager": XLMRobertaSelfAttention, +} + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->XLMRoberta,ROBERTA->XLM_ROBERTA class XLMRobertaAttention(nn.Module): def __init__(self, config, position_embedding_type=None): super().__init__() - self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.self = XLM_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) self.output = XLMRobertaSelfOutput(config) self.pruned_heads = set() @@ -690,7 +697,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRoberta + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRoberta def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -723,7 +730,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel): output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 1c17652dfa..d8994e335b 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -664,7 +664,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel): an input to the forward pass. .. 
_*Attention is all you need*: https://arxiv.org/abs/1706.03762 """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->XLMRobertaXL + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->XLMRobertaXL def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config @@ -697,7 +697,7 @@ class XLMRobertaXLModel(XLMRobertaXLPreTrainedModel): output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) - # Copied from transformers.models.bert.modeling_bert.BertModel.forward + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward def forward( self, input_ids: Optional[torch.Tensor] = None, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 2bf76a40d4..32e34ef668 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -783,7 +783,7 @@ class XmodModel(XmodPreTrainedModel): """ - # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Xmod + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Xmod def __init__(self, config, add_pooling_layer=True): super().__init__(config) self.config = config diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index bdc812ff27..ff9a628020 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -18,7 +18,14 @@ import unittest from transformers import BertConfig, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device +from transformers.testing_utils import ( + CaptureLogger, + require_torch, + require_torch_accelerator, + require_torch_sdpa, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -621,6 +628,79 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device) loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device)) + # This test was copied from the common test_eager_matches_sdpa_generate(), but without low_cpu_mem_usage=True. + # TODO: Remove this and use the parent method (in common tests) once BERT supports low_cpu_mem_usage=True. 
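+    # The body below loads the same checkpoint with attn_implementation="sdpa" and "eager",
+    # then checks that greedy generation (do_sample=False) produces matching outputs for both.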
+ @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + max_new_tokens = 30 + + if len(self.all_generative_model_classes) == 0: + self.skipTest(f"{self.__class__.__name__} tests a model that does support generate: skipping this test") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model_sdpa = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + ).to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + attn_implementation="eager", + ).to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + + # Just test that a large cache works as expected + res_eager = model_eager.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + res_sdpa = model_sdpa.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False + ) + + self.assertTrue(torch.allclose(res_eager, res_sdpa)) + @require_torch class BertModelIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1c099a4035..061c0000ce 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3603,12 +3603,14 @@ class ModelTesterMixin: self.assertTrue(model_eager.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") has_sdpa = False for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: has_sdpa = True break if not has_sdpa and model_sdpa.config.model_type != "falcon": @@ -3691,19 +3693,21 @@ class ModelTesterMixin: decoder_input_ids = 
decoder_input_ids.to(torch_device) # TODO: never an `attention_mask` arg here? - other_inputs = { + processed_inputs = { + model.main_input_name: dummy_input, "decoder_input_ids": decoder_input_ids, "decoder_attention_mask": dummy_attention_mask, "output_hidden_states": True, } else: - other_inputs = { + processed_inputs = { + model.main_input_name: dummy_input, "output_hidden_states": True, } # Otherwise fails for e.g. WhisperEncoderModel if "attention_mask" in inspect.signature(model_eager.forward).parameters: - other_inputs["attention_mask"] = dummy_attention_mask + processed_inputs["attention_mask"] = dummy_attention_mask # TODO: test gradients as well (& for FA2 as well!) with torch.no_grad(): @@ -3712,8 +3716,9 @@ class ModelTesterMixin: enable_math=True, enable_mem_efficient=enable_kernels, ): - outputs_eager = model_eager(dummy_input, **other_inputs) - outputs_sdpa = model_sdpa(dummy_input, **other_inputs) + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) logits_eager = ( outputs_eager.hidden_states[-1] @@ -3799,6 +3804,7 @@ class ModelTesterMixin: self.skipTest(f"{model_class.__name__} does not support SDPA") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) if config.model_type in ["llava", "llava_next", "vipllava"]: self.skipTest("Llava-like models currently (transformers==4.39.1) requires an attention_mask input") if config.model_type in ["idefics"]: @@ -3867,12 +3873,14 @@ class ModelTesterMixin: self.assertTrue(model_eager.config._attn_implementation == "eager") for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") has_sdpa = False for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: has_sdpa = True break if not has_sdpa:
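As a supplementary check (a sketch, not part of the patch), the dispatch added above can be verified by inspecting which self-attention class a loaded `BertModel` ends up with; `bert-base-uncased` is the checkpoint already used in the documentation section:

```
from transformers import BertModel

# Requesting SDPA explicitly; with this patch it is also the default on torch>=2.1.1.
model_sdpa = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")
print(type(model_sdpa.encoder.layer[0].attention.self).__name__)  # BertSdpaSelfAttention

# Forcing the eager implementation keeps the original attention modules.
model_eager = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
print(type(model_eager.encoder.layer[0].attention.self).__name__)  # BertSelfAttention
```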