[`FA2`] Add flash attention for `DistilBert` (#26489)
* flash attention added for DistilBert
* fixes
* removed padding_masks
* Update modeling_distilbert.py
* Update test_modeling_distilbert.py
* style fix
This commit is contained in:
parent
5964f820db
commit
1ac2463dfe
@@ -133,6 +133,37 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
- A blog post on how to [deploy DistilBERT with Amazon SageMaker](https://huggingface.co/blog/deploy-hugging-face-models-easily-with-amazon-sagemaker).
- A blog post on how to [Deploy BERT with Hugging Face Transformers, Amazon SageMaker and Terraform module](https://www.philschmid.de/terraform-huggingface-amazon-sagemaker).

## Combining DistilBERT and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2 (see the official documentation of the flash-attn repository for supported GPUs), and load your model in half-precision (e.g. `torch.float16`).

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModel

>>> device = "cuda"  # the device to load the model onto

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)

>>> text = "Replace me by any text you'd like."

>>> encoded_input = tokenizer(text, return_tensors="pt").to(device)
>>> model.to(device)

>>> output = model(**encoded_input)
```
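Batched, padded inputs work the same way. The sketch below is not part of this commit and only uses the standard tokenizer API; padding the batch puts zeros in the attention mask, which routes the computation through the unpadding path added by this PR:

```python
>>> texts = ["Hello world", "A slightly longer sentence than the first one."]
>>> encoded_input = tokenizer(texts, padding=True, return_tensors="pt").to(device)

>>> with torch.no_grad():
...     output = model(**encoded_input)

>>> output.last_hidden_state.shape  # (batch_size, padded_seq_length, dim)
```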

## DistilBertConfig

[[autodoc]] DistilBertConfig
@@ -24,6 +24,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -44,12 +45,18 @@ from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    logging,
    replace_return_docstrings,
)
from .configuration_distilbert import DistilBertConfig


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"
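If you want to check up front whether the optional dependency is usable in your environment, the helper guarded on above is also importable from `transformers.utils`. This is an illustrative check, not part of the diff:

```python
from transformers.utils import is_flash_attn_2_available

# True only if a compatible flash-attn 2.x build is importable in the current environment
print(is_flash_attn_2_available())
```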
@@ -69,6 +76,19 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
    if is_deepspeed_zero3_enabled():
        import deepspeed
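To make the unpadding bookkeeping above concrete, here is a small self-contained sketch (not part of the commit) that runs the same logic on a toy right-padded mask; the values in the comments follow directly from the definitions above:

```python
import torch
import torch.nn.functional as F


def _get_unpad_data(attention_mask):
    # identical logic to the helper added in the hunk above
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# batch of 2 sequences; the first one has a single padding token at the end
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)

print(indices)     # tensor([0, 1, 3, 4, 5]) -> flat positions of the real tokens
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -> cumulative sequence lengths
print(max_len)     # 3
```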
@@ -141,10 +161,12 @@ class Embeddings(nn.Module):
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        self.is_causal = False

        # Have an even number of multi heads that divide the dimensions
        if self.dim % self.n_heads != 0:
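The new `self.is_causal = False` attribute is what later gets forwarded to the flash-attention kernels as `causal=self.is_causal`; DistilBERT is a bidirectional encoder, so no causal mask must be applied. The sketch below illustrates the difference using PyTorch's built-in `scaled_dot_product_attention` as a stand-in for the flash-attn kernels (an assumption of this sketch, not code from the PR):

```python
import torch
import torch.nn.functional as F

# (bs, n_heads, seq_length, head_dim)
q = k = v = torch.randn(1, 2, 4, 8)

bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The last position attends to every key in both cases, so it matches;
# earlier positions lose access to "future" tokens under the causal mask,
# which would be wrong for a bidirectional encoder like DistilBERT.
print(torch.allclose(bidirectional[:, :, -1], causal[:, :, -1]))  # True
print(torch.allclose(bidirectional, causal))                      # False
```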
@@ -240,6 +262,178 @@ class MultiHeadSelfAttention(nn.Module):
            return (context,)


class DistilBertFlashAttention2(MultiHeadSelfAttention):
    """
    DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
    stay untouched. The only required change is in the forward pass, which needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        batch_size, q_length, dim = query.size()

        dim_per_head = self.dim // self.n_heads

        def reshape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(batch_size, -1, self.n_heads, dim_per_head)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        query_states = reshape(self.q_lin(query))
        key_states = reshape(self.k_lin(key))
        value_states = reshape(self.v_lin(value))

        attn_dropout = self.config.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        if query_states.dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_lin.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_weights = self._flash_attention_forward(
            query_states, key_states, value_states, mask, q_length, dropout=attn_dropout
        )

        attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
        attn_output = self.out_lin(attn_weights_reshaped)

        if output_attentions:
            return (attn_output, attn_weights)
        else:
            return (attn_output,)

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward with causal=True->causal=False
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->n_heads
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


class FFN(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()
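One detail worth highlighting from the class above: the eager `MultiHeadSelfAttention` works on tensors laid out as `(bs, n_heads, seq_length, head_dim)`, while `reshape` in `DistilBertFlashAttention2` keeps the head dimension after the sequence dimension, i.e. `(bs, seq_length, n_heads, head_dim)`, which is the layout `flash_attn_func`/`flash_attn_varlen_func` expect. A shape-only sketch (illustrative; the dimensions are made up):

```python
import torch

bs, seq_length, n_heads, head_dim = 2, 5, 12, 64
hidden = torch.randn(bs, seq_length, n_heads * head_dim)

# eager attention layout: (bs, n_heads, seq_length, head_dim)
eager_layout = hidden.view(bs, seq_length, n_heads, head_dim).transpose(1, 2)

# flash-attention layout used by DistilBertFlashAttention2.reshape: (bs, seq_length, n_heads, head_dim)
flash_layout = hidden.view(bs, -1, n_heads, head_dim)

print(eager_layout.shape)  # torch.Size([2, 12, 5, 64])
print(flash_layout.shape)  # torch.Size([2, 5, 12, 64])
```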
@@ -269,7 +463,11 @@ class TransformerBlock(nn.Module):
        if config.dim % config.n_heads != 0:
            raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")

        self.attention = MultiHeadSelfAttention(config)
        self.attention = (
            MultiHeadSelfAttention(config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else DistilBertFlashAttention2(config)
        )
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
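A minimal way to see this dispatch in action is to inspect which attention class gets instantiated. The sketch below relies on the private `_flash_attn_2_enabled` flag that `from_pretrained(..., use_flash_attention_2=True)` sets at this point in the library's history, so treat it purely as an illustration, not a recommended API:

```python
from transformers import DistilBertConfig, DistilBertModel

config = DistilBertConfig()
config._flash_attn_2_enabled = True  # private flag checked by TransformerBlock above

model = DistilBertModel(config)
print(type(model.transformer.layer[0].attention).__name__)  # DistilBertFlashAttention2
```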
@@ -407,6 +605,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
    load_tf_weights = None
    base_model_prefix = "distilbert"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
@@ -588,14 +787,17 @@ class DistilBertModel(DistilBertPreTrainedModel):

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)

        if getattr(self.config, "_flash_attn_2_enabled", False):
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            if attention_mask is None:
                attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        return self.transformer(
            x=embeddings,
            attn_mask=attention_mask,
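The model-level change above means that when Flash Attention 2 is enabled and the batch contains no padding at all, the mask is dropped entirely so the dense `flash_attn_func` fast path is taken; only masks that actually contain zeros go through `_upad_input`/`flash_attn_varlen_func`. A tiny sketch of that condition (mirroring the line above, not part of the diff):

```python
import torch


def mask_for_flash_attention(attention_mask):
    # same condition as in DistilBertModel.forward when _flash_attn_2_enabled is set
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None


print(mask_for_flash_attention(torch.tensor([[1, 1, 1, 1]])))  # None -> dense fast path
print(mask_for_flash_attention(torch.tensor([[1, 1, 1, 0]])))  # mask kept -> unpad/varlen path
```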
@@ -16,8 +16,10 @@ import os
import tempfile
import unittest

from pytest import mark

from transformers import DistilBertConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -285,6 +287,114 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
            loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
            loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))

    # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
    @require_flash_attn
    @require_torch_accelerator
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference(self):
        import torch

        for model_class in self.all_model_classes:
            dummy_input = torch.LongTensor(
                [
                    [1, 2, 3, 4],
                    [1, 2, 8, 9],
                    [1, 2, 11, 12],
                    [1, 2, 13, 14],
                ]
            ).to(torch_device)
            dummy_attention_mask = torch.LongTensor(
                [
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                ]
            ).to(torch_device)

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
                )
                model.to(torch_device)

                logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
                logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]

                self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))

                output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits_fa = output_fa.hidden_states[-1]

                output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits = output.hidden_states[-1]

                self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))

    # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
    @require_flash_attn
    @require_torch_accelerator
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_padding_right(self):
        import torch

        for model_class in self.all_model_classes:
            dummy_input = torch.LongTensor(
                [
                    [1, 2, 3, 4],
                    [1, 2, 8, 9],
                    [1, 2, 11, 12],
                    [1, 2, 13, 14],
                ]
            ).to(torch_device)
            dummy_attention_mask = torch.LongTensor(
                [
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                    [0, 1, 1, 1],
                ]
            ).to(torch_device)

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
                )
                model.to(torch_device)

                logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
                logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]

                self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))

                output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits_fa = output_fa.hidden_states[-1]

                output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
                logits = output.hidden_states[-1]

                self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))


@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):