[`FA2`] Add flash attention for opt (#26414)

* added flash attention for opt

* added to list

* fix use cache (#3)

* style fix

* fix text

* test fix2

* reverted until 689f599

* torch fx tests are working now!

* small fix

* added TODO docstring

* changes

* comments and .md file modification

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Susnato Dhar 2023-11-23 15:46:51 +05:30 committed by GitHub
parent 1ddc4fa60e
commit 3bc50d81e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 323 additions and 30 deletions

View File

@ -62,6 +62,55 @@ The resource should ideally demonstrate something new instead of duplicating an
- A blog post on [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models) with OPT.
## Combining OPT and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.
```bash
pip install -U flash-attn --no-build-isolation
```
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
To load and run a model using Flash Attention 2, refer to the snippet below:
```python
>>> import torch
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
"Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived "
"there?")
>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)
>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False)
>>> tokenizer.batch_decode(generated_ids)[0]
'</s>A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: I have lived here for about a year.\nHuman: What is your favorite place to eat?\nStatue: I love'
```
### Expected speedups
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `facebook/opt-2.7b` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths.
<div style="text-align: center">
<img src="https://user-images.githubusercontent.com/49240599/281101546-d2fca6d2-ee44-48f3-9534-ba8d5bee4531.png">
</div>
Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `facebook/opt-350m` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths.
<div style="text-align: center">
<img src="https://user-images.githubusercontent.com/49240599/281101682-d1144e90-0dbc-46f4-8fc8-c6206cb793c9.png">
</div>
## OPTConfig
[[autodoc]] OPTConfig

View File

@ -16,6 +16,7 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -33,12 +34,18 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from .configuration_opt import OPTConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
@ -64,6 +71,19 @@ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
class OPTLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
@ -93,30 +113,49 @@ class OPTAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
config: OPTConfig,
is_decoder: bool = False,
bias: bool = True,
**kwargs,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
"""
If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
warning and return that value, otherwise take the equivalent config.config_arg_name
"""
val = None
if fn_arg_name in kwargs:
logging.warning(
"Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38."
" Please set it in the config instead"
)
val = kwargs.pop(fn_arg_name)
else:
val = getattr(config, config_arg_name)
return val
self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs)
self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs)
self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs)
self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs)
self.head_dim = self.embed_dim // self.num_heads
self.is_causal = True
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@ -242,17 +281,210 @@ class OPTAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
class OptFlashAttention2(OPTAttention):
"""
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of flash
attention and deal with padding tokens in case the input contains any of them.
"""
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, _, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
query_length = query_states.shape[1]
tgt_len = key_states.shape[-2]
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
attn_dropout = self.dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout
)
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
attn_output = self.out_proj(attn_weights_reshaped)
if not output_attentions:
attn_weights_reshaped = None
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)
if not getattr(config, "_flash_attn_2_enabled", False):
self.self_attn = OPTAttention(config=config, is_decoder=True)
else:
self.self_attn = OptFlashAttention2(config=config, is_decoder=True)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@ -368,6 +600,7 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
def _init_weights(self, module):
std = self.config.init_std
@ -581,16 +814,27 @@ class OPTDecoder(OPTPreTrainedModel):
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if attention_mask is None
else attention_mask
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None: