diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index da33488aa7..e384b47b3f 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -40,6 +40,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -47,33 +48,13 @@ from ...utils import (
 from .configuration_phi3 import Phi3Config
 
 
-logger = logging.get_logger(__name__)
-
-# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
-# if is_flash_attn_2_available():
-_flash_supports_window_size = False
-try:
+if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
-    if not _flash_supports_window_size:
-        raise ValueError("Please update flash-attention to support window size.")
-
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-    from flash_attn.ops.activations import swiglu
-    from flash_attn.ops.rms_norm import RMSNorm as Phi3FlashRMSNorm
-# else:
-except ImportError as error:
-    logger.warning(
-        f"Flash Attention or Flash Attention Submodules not found, consider installing for better performance: {error}."
-    )
-    if not _flash_supports_window_size:
-        logger.warning(
-            "This version of flash does not support window size. Please use `attn_implementation='eager'` or upgrade flash-attn library."
-        )
-    swiglu = None
-    Phi3FlashRMSNorm = None
+logger = logging.get_logger(__name__)
 
 _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
 _CONFIG_FOR_DOC = "Phi3Config"
@@ -103,9 +84,6 @@ class Phi3RMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
-PHI3_NORM_CLASS = Phi3RMSNorm if Phi3FlashRMSNorm is None else Phi3FlashRMSNorm
-
-
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -271,13 +249,8 @@ class Phi3MLP(nn.Module):
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         y = self.gate_up_proj(hidden_states)
 
-        # Special case for SwiGLU
-        if self.config.hidden_act == "silu" and swiglu is not None:
-            gate, y = y.chunk(2, dim=-1)
-            y = swiglu(gate, y)
-        else:
-            gate, y = y.chunk(2, dim=-1)
-            y = y * self.activation_fn(gate)
+        gate, y = y.chunk(2, dim=-1)
+        y = y * self.activation_fn(gate)
 
         return self.down_proj(y)
 
@@ -851,11 +824,11 @@ class Phi3DecoderLayer(nn.Module):
         self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
 
         self.mlp = Phi3MLP(config)
-        self.input_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
         self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
-        self.post_attention_layernorm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -1059,7 +1032,7 @@ class Phi3Model(Phi3PreTrainedModel):
             [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self._attn_implementation = config._attn_implementation
-        self.norm = PHI3_NORM_CLASS(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
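After this patch, `Phi3MLP.forward` always takes the plain PyTorch path: the fused `gate_up_proj` output is split in two, the gate half runs through the configured activation, and the elementwise product is projected back down by `down_proj`. Below is a minimal, self-contained sketch of that gated-MLP pattern; the `GatedMLP` name, the hard-coded `nn.SiLU`, and the tensor sizes are illustrative stand-ins, not values taken from the actual Phi-3 configuration.

```python
import torch
import torch.nn as nn


class GatedMLP(nn.Module):
    """Sketch of the gated-MLP pattern the simplified Phi3MLP.forward reduces to:
    no fused flash-attn swiglu kernel, just chunk + activation in plain PyTorch."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One fused projection produces both the gate and the value halves.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()  # Phi-3 configures "silu"; any activation would slot in here

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        y = self.gate_up_proj(hidden_states)
        gate, y = y.chunk(2, dim=-1)      # split the fused projection into gate and value
        y = y * self.activation_fn(gate)  # SwiGLU-style gating, no special-cased kernel
        return self.down_proj(y)


# Quick shape check with illustrative (non-Phi-3) dimensions.
x = torch.randn(2, 16, 64)
print(GatedMLP(hidden_size=64, intermediate_size=128)(x).shape)  # torch.Size([2, 16, 64])
```

Dropping the `swiglu` and `Phi3FlashRMSNorm` fallbacks also means the norm layers are always `Phi3RMSNorm`, and guarding the remaining flash-attn imports behind `is_flash_attn_2_available()` follows the pattern other models in the library already use for this optional dependency, which is what lets the try/except block and its warnings go away.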