fix(phi3): Removes additional flash-attention usage, e.g., swiglu and rmsnorm.

Gustavo de Rosa 2024-04-23 08:39:27 -07:00
parent 508ec8ef31
commit 56e6464f1a
1 changed file with 9 additions and 36 deletions


@@ -40,6 +40,7 @@ from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
@@ -47,33 +48,13 @@ from ...utils import (
from .configuration_phi3 import Phi3Config

logger = logging.get_logger(__name__)

# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
# if is_flash_attn_2_available():
_flash_supports_window_size = False
try:
    if is_flash_attn_2_available():
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

        _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
        if not _flash_supports_window_size:
            raise ValueError("Please update flash-attention to support window size.")

        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
        from flash_attn.ops.activations import swiglu
        from flash_attn.ops.rms_norm import RMSNorm as Phi3FlashRMSNorm
# else:
except ImportError as error:
    logger.warning(
        f"Flash Attention or Flash Attention Submodules not found, consider installing for better performance: {error}."
    )
    if not _flash_supports_window_size:
        logger.warning(
            "This version of flash does not support window size. Please use `attn_implementation='eager'` or upgrade flash-attn library."
        )
    swiglu = None
    Phi3FlashRMSNorm = None

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
_CONFIG_FOR_DOC = "Phi3Config"
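For reference, once the optional swiglu and fused RMSNorm imports are dropped, the module-level flash-attn gate reduces to the pattern used by other Transformers models such as Mistral. The following is a minimal sketch of that pattern, not the verbatim file; the exact try/except vs. `if is_flash_attn_2_available()` arrangement and the warning text vary across revisions of this branch.

import inspect

from transformers.utils import is_flash_attn_2_available, logging

logger = logging.get_logger(__name__)

_flash_supports_window_size = False
if is_flash_attn_2_available():
    # Only the attention kernels and padding helpers are imported;
    # swiglu and the fused RMSNorm are no longer touched.
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa: F401

    # Sliding-window attention needs a flash-attn build that exposes `window_size`.
    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)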
@@ -103,9 +84,6 @@ class Phi3RMSNorm(nn.Module):
        return self.weight * hidden_states.to(input_dtype)

PHI3_NORM_CLASS = Phi3RMSNorm if Phi3FlashRMSNorm is None else Phi3FlashRMSNorm

# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
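With Phi3FlashRMSNorm removed, every normalization layer in the model is Phi3RMSNorm, whose last line is visible at the top of this hunk. For context, the full class follows the standard Llama-style RMSNorm; shown here as a self-contained sketch.

import torch
from torch import nn


class Phi3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """Phi3RMSNorm is equivalent to T5LayerNorm (no mean subtraction, no bias)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # Normalize by the root-mean-square over the last dimension, then rescale.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)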
@@ -271,13 +249,8 @@ class Phi3MLP(nn.Module):
    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        y = self.gate_up_proj(hidden_states)

        # Special case for SwiGLU
        if self.config.hidden_act == "silu" and swiglu is not None:
            gate, y = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            gate, y = y.chunk(2, dim=-1)
            y = y * self.activation_fn(gate)
        gate, y = y.chunk(2, dim=-1)
        y = y * self.activation_fn(gate)

        return self.down_proj(y)
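After this hunk the MLP always takes the plain PyTorch SwiGLU path: gate_up_proj produces the gate and up projections in a single matmul, the gate passes through the configured activation (SiLU for Phi-3), and the gated result goes through down_proj. Below is a self-contained sketch under those assumptions; the class name and the Phi-3-mini sizes are illustrative, not the actual module.

import torch
from torch import nn
import torch.nn.functional as F


class Phi3MLPSketch(nn.Module):
    """Minimal stand-in for Phi3MLP after the swiglu special case was removed."""

    def __init__(self, hidden_size: int = 3072, intermediate_size: int = 8192):
        super().__init__()
        # One fused projection producing [gate | up] along the last dimension.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = F.silu  # config.hidden_act == "silu" for Phi-3

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        y = self.gate_up_proj(hidden_states)
        gate, y = y.chunk(2, dim=-1)      # split the fused projection
        y = y * self.activation_fn(gate)  # SwiGLU in plain PyTorch
        return self.down_proj(y)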
@@ -851,11 +824,11 @@ class Phi3DecoderLayer(nn.Module):
        self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)

        self.mlp = Phi3MLP(config)
        self.input_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
        self.post_attention_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
@@ -1059,7 +1032,7 @@ class Phi3Model(Phi3PreTrainedModel):
            [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
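The change is a simplification rather than a behavioral one: flash-attn's swiglu(gate, y) computes the same y * silu(gate) as the retained branch, and Phi3RMSNorm performs the same normalization as the dropped fused kernel (up to kernel-level numerics), so replacing PHI3_NORM_CLASS with Phi3RMSNorm throughout leaves outputs unchanged. A hypothetical spot-check of the activation equivalence is sketched below; it assumes a CUDA device, and flash-attn is optional.

import torch
import torch.nn.functional as F

# Hypothetical sanity check: the plain PyTorch SwiGLU should match flash-attn's
# fused kernel, so dropping the fused path does not change model outputs.
gate = torch.randn(2, 16, 8192, device="cuda", dtype=torch.float16)
up = torch.randn_like(gate)

reference = up * F.silu(gate)  # the branch kept by this commit

try:
    from flash_attn.ops.activations import swiglu  # the import removed by this commit

    fused = swiglu(gate, up)
    print("max abs diff:", (fused - reference).abs().max().item())
except ImportError:
    print("flash-attn not installed; only the plain PyTorch path is available.")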