fix(phi3): Removes additional flash-attention usage, e.g., swiglu and rmsnorm.
This commit is contained in:
parent
508ec8ef31
commit
56e6464f1a
|
@ -40,6 +40,7 @@ from ...utils import (
|
|||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
|
@ -47,33 +48,13 @@ from ...utils import (
|
|||
from .configuration_phi3 import Phi3Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
|
||||
# if is_flash_attn_2_available():
|
||||
_flash_supports_window_size = False
|
||||
try:
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
|
||||
|
||||
if not _flash_supports_window_size:
|
||||
raise ValueError("Please update flash-attention to support window size.")
|
||||
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
from flash_attn.ops.activations import swiglu
|
||||
from flash_attn.ops.rms_norm import RMSNorm as Phi3FlashRMSNorm
|
||||
# else:
|
||||
except ImportError as error:
|
||||
logger.warning(
|
||||
f"Flash Attention or Flash Attention Submodules not found, consider installing for better performance: {error}."
|
||||
)
|
||||
if not _flash_supports_window_size:
|
||||
logger.warning(
|
||||
"This version of flash does not support window size. Please use `attn_implementation='eager'` or upgrade flash-attn library."
|
||||
)
|
||||
swiglu = None
|
||||
Phi3FlashRMSNorm = None
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
|
||||
_CONFIG_FOR_DOC = "Phi3Config"
|
||||
|
@ -103,9 +84,6 @@ class Phi3RMSNorm(nn.Module):
|
|||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
PHI3_NORM_CLASS = Phi3RMSNorm if Phi3FlashRMSNorm is None else Phi3FlashRMSNorm
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
|
@ -271,13 +249,8 @@ class Phi3MLP(nn.Module):
|
|||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
y = self.gate_up_proj(hidden_states)
|
||||
|
||||
# Special case for SwiGLU
|
||||
if self.config.hidden_act == "silu" and swiglu is not None:
|
||||
gate, y = y.chunk(2, dim=-1)
|
||||
y = swiglu(gate, y)
|
||||
else:
|
||||
gate, y = y.chunk(2, dim=-1)
|
||||
y = y * self.activation_fn(gate)
|
||||
gate, y = y.chunk(2, dim=-1)
|
||||
y = y * self.activation_fn(gate)
|
||||
|
||||
return self.down_proj(y)
|
||||
|
||||
|
@ -851,11 +824,11 @@ class Phi3DecoderLayer(nn.Module):
|
|||
self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = Phi3MLP(config)
|
||||
self.input_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.post_attention_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -1059,7 +1032,7 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||
[Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
|
|
Loading…
Reference in New Issue