[Phi] Extend implementation to use GQA/MQA. (#28163)
* chore(phi): Updates configuration_phi with missing keys.
* chore(phi): Adds first draft of combined modeling_phi.
* fix(phi): Fixes according to latest review.
* fix(phi): Removes pad_vocab_size_multiple to prevent inconsistencies.
* fix(phi): Fixes unit and integration tests.
* fix(phi): Ensures that everything works with microsoft/phi-1 for first integration.
* fix(phi): Fixes output of docstring generation.
* fix(phi): Fixes according to latest review.
* fix(phi): Fixes according to latest review.
* fix(tests): Re-enables Phi-1.5 test.
* fix(phi): Fixes attention overflow on PhiAttention (for Phi-2).
* fix(phi): Improves how queries and keys are upcast.
* fix(phi): Small updates on latest changes.
This commit is contained in: parent d560637885, commit 5509058561
@@ -23,8 +23,9 @@ from ...utils import logging
 logger = logging.get_logger(__name__)

 PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "susnato/phi-1_dev": "https://huggingface.co/susnato/phi-1_dev/resolve/main/config.json",
-    "susnato/phi-1_5_dev": "https://huggingface.co/susnato/phi-1_5_dev/resolve/main/config.json",
+    "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
+    "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
+    "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
 }


@@ -33,7 +34,7 @@ class PhiConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the Phi
-    [susnato/phi-1_dev](https://huggingface.co/susnato/phi-1_dev).
+    [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).

     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -50,6 +51,14 @@ class PhiConfig(PretrainedConfig):
             Number of hidden layers in the Transformer decoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
         resid_pdrop (`float`, *optional*, defaults to 0.0):
             Dropout probability for mlp outputs.
         embd_pdrop (`int`, *optional*, defaults to 0.0):
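To make the `num_key_value_heads` semantics above concrete, here is a small sketch (not part of the diff; the helper name is made up for illustration) of how the key/value head count selects between MHA, MQA, and GQA:

```python
from typing import Optional


def attention_variant(num_attention_heads: int, num_key_value_heads: Optional[int]) -> str:
    """Illustrative only: mirrors the rule described in the docstring above."""
    if num_key_value_heads is None or num_key_value_heads == num_attention_heads:
        return "MHA"  # one key/value head per query head
    if num_key_value_heads == 1:
        return "MQA"  # all query heads share a single key/value head
    return "GQA"  # query heads are grouped over num_key_value_heads key/value heads


print(attention_variant(32, None))  # MHA (default: falls back to num_attention_heads)
print(attention_variant(32, 1))     # MQA
print(attention_variant(32, 8))     # GQA, 4 query heads per key/value head
```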
@@ -83,7 +92,7 @@ class PhiConfig(PretrainedConfig):
         partial_rotary_factor (`float`, *optional*, defaults to 0.5):
             Percentage of the query and keys which will have rotary embedding.
         qk_layernorm (`bool`, *optional*, defaults to `False`):
-            Whether or not to normalize the Queries and Keys after projecting the hidden states
+            Whether or not to normalize the Queries and Keys after projecting the hidden states.
         bos_token_id (`int`, *optional*, defaults to 1):
             Denotes beginning of sequences token id.
         eos_token_id (`int`, *optional*, defaults to 2):
@@ -95,7 +104,7 @@ class PhiConfig(PretrainedConfig):
     >>> from transformers import PhiModel, PhiConfig

     >>> # Initializing a Phi-1 style configuration
-    >>> configuration = PhiConfig.from_pretrained("susnato/phi-1_dev")
+    >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")

     >>> # Initializing a model from the configuration
     >>> model = PhiModel(configuration)
@@ -114,6 +123,7 @@ class PhiConfig(PretrainedConfig):
         intermediate_size=8192,
         num_hidden_layers=24,
         num_attention_heads=32,
+        num_key_value_heads=None,
         resid_pdrop=0.0,
         embd_pdrop=0.0,
         attention_dropout=0.0,
@@ -136,6 +146,11 @@ class PhiConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attention_dropout = attention_dropout
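A hedged usage sketch of the defaulting logic added to `__init__` above: leaving `num_key_value_heads` unset keeps the original multi-head behaviour, while a smaller value turns on GQA (this assumes a transformers version that already includes this change):

```python
from transformers import PhiConfig

# Default: num_key_value_heads falls back to num_attention_heads -> plain multi-head attention.
mha_config = PhiConfig(num_attention_heads=32)
assert mha_config.num_key_value_heads == 32

# Grouped Query Attention: 32 query heads share 8 key/value heads (4 query heads per group).
gqa_config = PhiConfig(num_attention_heads=32, num_key_value_heads=8)
assert gqa_config.num_attention_heads % gqa_config.num_key_value_heads == 0
```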
@@ -54,12 +54,13 @@ if is_flash_attn_2_available():

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "susnato/phi-1_dev"
+_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
 _CONFIG_FOR_DOC = "PhiConfig"

 PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "susnato/phi-1_dev",
-    "susnato/phi-1_5_dev",
+    "microsoft/phi-1",
+    "microsoft/phi-1_5",
+    "microsoft/phi-2",
     # See all Phi models at https://huggingface.co/models?filter=phi
 ]


@@ -214,7 +215,19 @@ class PhiMLP(nn.Module):
         return hidden_states


-# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention with Persimmon->Phi,persimmon->phi
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class PhiAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

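For readers skimming the hunk above, this standalone snippet mirrors the `repeat_kv` helper introduced by the commit and shows how grouped key/value heads are broadcast back to one head per query head before the attention matmul (shapes are illustrative):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


key_states = torch.randn(2, 8, 16, 64)     # 8 key/value heads
expanded = repeat_kv(key_states, n_rep=4)  # expanded to 32 heads, one per query head
print(expanded.shape)                      # torch.Size([2, 32, 16, 64])
```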
@@ -229,9 +242,12 @@ class PhiAttention(nn.Module):
                 "when creating this class."
             )

+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.partial_rotary_factor = config.partial_rotary_factor
@@ -242,10 +258,13 @@ class PhiAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
-        self.qk_layernorm = config.qk_layernorm

+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+
+        self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
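A quick shape check for the split projections introduced above (the sizes below are illustrative, not taken from a released checkpoint): under GQA the key/value projections shrink by a factor of `num_heads / num_key_value_heads` relative to the query projection.

```python
import torch
from torch import nn

hidden_size, num_heads, num_kv_heads = 2560, 32, 8  # illustrative sizes
head_dim = hidden_size // num_heads                 # 80

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)     # 2560 -> 2560
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)  # 2560 -> 640
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)  # 2560 -> 640

x = torch.randn(1, 16, hidden_size)
print(q_proj(x).shape, k_proj(x).shape)  # torch.Size([1, 16, 2560]) torch.Size([1, 16, 640])
```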
@@ -253,7 +272,7 @@ class PhiAttention(nn.Module):
             self.k_layernorm = nn.LayerNorm(
                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
             )
-        self.attention_dropout = nn.Dropout(config.attention_dropout)
+
         self._init_rope()

     def _init_rope(self):
@@ -283,23 +302,6 @@ class PhiAttention(nn.Module):
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
-    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
-        storage as `fused_qkv`
-
-        Args:
-            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
-
-        Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
-            value: [batch_size, seq_length, num_heads, head_dim]
-        """
-        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -311,20 +313,17 @@ class PhiAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

-        # [batch_size, seq_length, 3 x hidden_size]
-        fused_qkv = self.query_key_value(hidden_states)
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)

-        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
-        query_states = query_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -354,11 +353,16 @@ class PhiAttention(nn.Module):
         key_states = torch.cat((key_rot, key_pass), dim=-1)

         if past_key_value is not None:
-            # Specific to RoPE models with partial rotation
             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+        attn_weights = torch.matmul(
+            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
+        ) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
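The overflow fix above computes the query/key matmul in fp32 before scaling; a minimal sketch of the same idea on random fp16 tensors (shapes are illustrative, this is not the module itself):

```python
import math

import torch

bsz, num_heads, q_len, head_dim = 1, 32, 16, 80
query_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)
key_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)

# Compute attention scores in fp32 so large logits cannot overflow fp16,
# run the softmax in fp32 as well, then cast back to the low-precision dtype.
attn_weights = torch.matmul(
    query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(head_dim)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(torch.float16)
print(attn_weights.dtype, attn_weights.shape)  # torch.float16 torch.Size([1, 32, 16, 16])
```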
@@ -374,8 +378,8 @@ class PhiAttention(nn.Module):
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
-        attn_weights = self.attention_dropout(attn_weights)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

         attn_output = torch.matmul(attn_weights, value_states)

@@ -398,9 +402,9 @@ class PhiAttention(nn.Module):

 class PhiFlashAttention2(PhiAttention):
     """
-    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays untouched.
-    The only required change would be on the forward pass where it needs to correctly call the public API of flash
-    attention and deal with padding tokens in case the input contains any of them.
+    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
     """

     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
@@ -415,11 +419,12 @@ class PhiFlashAttention2(PhiAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # PhiFlashAttention2 attention does not support output_attentions

@@ -427,20 +432,20 @@ class PhiFlashAttention2(PhiAttention):

         bsz, q_len, _ = hidden_states.size()

-        # [batch_size, seq_length, 3 x hidden_size]
-        fused_qkv = self.query_key_value(hidden_states)
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)

-        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
-        query_states = query_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -467,15 +472,13 @@ class PhiFlashAttention2(PhiAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        tgt_len = key_states.shape[2]
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.transpose(1, 2).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
-        value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
-
-        attn_dropout = self.config.attention_dropout if self.training else 0.0
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_dropout = self.attention_dropout if self.training else 0.0

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -506,7 +509,7 @@ class PhiFlashAttention2(PhiAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.dense(attn_output)

         if not output_attentions:
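As context for the extra transposes in the flash-attention path above, this sketch (plain PyTorch, no flash-attn call) shows the move from the cache-friendly `[batch, num_heads, seq_len, head_dim]` layout to the `[batch, seq_len, num_heads, head_dim]` layout that the flash-attention kernels expect, which is what the TODO about inefficiency refers to:

```python
import torch

bsz, num_heads, seq_len, head_dim = 2, 32, 16, 64
# Layout used inside the attention module and by the KV cache
key_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Flash attention expects [batch, seq_len, num_heads, head_dim]
key_states_fa = key_states.transpose(1, 2)
print(key_states_fa.shape)            # torch.Size([2, 16, 32, 64])
print(key_states_fa.is_contiguous())  # False: the transpose is only a view
```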
@@ -708,6 +711,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     config_class = PhiConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
@@ -745,7 +749,7 @@ PHI_INPUTS_DOCSTRING = r"""
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.

-            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).

             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -852,13 +856,13 @@ class PhiModel(PhiPreTrainedModel):

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
+            batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError("You have to specify either input_ids or inputs_embeds")

         past_key_values_length = 0

@@ -1020,8 +1024,8 @@ class PhiForCausalLM(PhiPreTrainedModel):
         ```python
         >>> from transformers import AutoTokenizer, PhiForCausalLM

-        >>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev")
-        >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
+        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

         >>> prompt = "This is an example script ."
         >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1029,7 +1033,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        'This is an example script .py file that uses the `os` module to create a new directory and write some text to it.\n\n``'
+        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
         ```"""

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -365,18 +365,18 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     @require_bitsandbytes
     @pytest.mark.flash_attn_test
     @slow
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
     def test_flash_attn_2_generate_padding_right(self):
         """
         Overwritting the common test as the test is flaky on tiny models
         """
         model = PhiForCausalLM.from_pretrained(
-            "susnato/phi-1_5_dev",
+            "microsoft/phi-1",
             load_in_4bit=True,
             device_map={"": 0},
         )

-        tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

         texts = ["hi", "Hello this is a very long sentence"]

@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         output_native = tokenizer.batch_decode(output_native)

         model = PhiForCausalLM.from_pretrained(
-            "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
+            "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
         )

         output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -408,7 +408,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-1").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -424,7 +424,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -440,7 +440,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-2").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -450,8 +450,8 @@ class PhiIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3))

     def test_phi_2_generation(self):
-        model = PhiForCausalLM.from_pretrained("susnato/phi-2")
-        tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

         inputs = tokenizer(
             "Can you help me write a formal email to a potential business partner proposing a joint venture?",