[Phi] Extend implementation to use GQA/MQA. (#28163)

* chore(phi): Updates configuration_phi with missing keys.

* chore(phi): Adds first draft of combined modeling_phi.

* fix(phi): Fixes according to latest review.

* fix(phi): Removes pad_vocab_size_multiple to prevent inconsistencies.

* fix(phi): Fixes unit and integration tests.

* fix(phi): Ensures that everything works with microsoft/phi-1 for first integration.

* fix(phi): Fixes output of docstring generation.

* fix(phi): Fixes according to latest review.

* fix(phi): Fixes according to latest review.

* fix(tests): Re-enables Phi-1.5 test.

* fix(phi): Fixes attention overflow on PhiAttention (for Phi-2).

* fix(phi): Improves how queries and keys are upcast.

* fix(phi): Small updates on latest changes.
Gustavo de Rosa, 2024-01-11 11:58:02 -03:00, committed by GitHub
commit 5509058561 (parent d560637885)
3 changed files with 101 additions and 82 deletions

src/transformers/models/phi/configuration_phi.py

@@ -23,8 +23,9 @@ from ...utils import logging
 logger = logging.get_logger(__name__)

 PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "susnato/phi-1_dev": "https://huggingface.co/susnato/phi-1_dev/resolve/main/config.json",
-    "susnato/phi-1_5_dev": "https://huggingface.co/susnato/phi-1_5_dev/resolve/main/config.json",
+    "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
+    "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
+    "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
 }
@@ -33,7 +34,7 @@ class PhiConfig(PretrainedConfig):
     This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the Phi
-    [susnato/phi-1_dev](https://huggingface.co/susnato/phi-1_dev).
+    [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).

     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -50,6 +51,14 @@ class PhiConfig(PretrainedConfig):
             Number of hidden layers in the Transformer decoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
         resid_pdrop (`float`, *optional*, defaults to 0.0):
             Dropout probability for mlp outputs.
         embd_pdrop (`int`, *optional*, defaults to 0.0):
@@ -83,7 +92,7 @@ class PhiConfig(PretrainedConfig):
         partial_rotary_factor (`float`, *optional*, defaults to 0.5):
             Percentage of the query and keys which will have rotary embedding.
         qk_layernorm (`bool`, *optional*, defaults to `False`):
-            Whether or not to normalize the Queries and Keys after projecting the hidden states
+            Whether or not to normalize the Queries and Keys after projecting the hidden states.
         bos_token_id (`int`, *optional*, defaults to 1):
             Denotes beginning of sequences token id.
         eos_token_id (`int`, *optional*, defaults to 2):
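Since `partial_rotary_factor` controls how much of each head receives rotary embeddings, here is a minimal sketch of what that split means. The values are illustrative (Phi-1's defaults give `head_dim = 2048 // 32 = 64`), and the variable names mirror the `query_rot`/`query_pass` pattern that appears later in the modeling diff rather than quoting it exactly:

```python
import torch

# Illustrative shapes: hidden_size=2048, 32 heads -> head_dim=64 (Phi-1 defaults).
head_dim = 64
partial_rotary_factor = 0.5
rotary_ndims = int(head_dim * partial_rotary_factor)  # only the first 32 dims are rotated

# [batch, num_heads, seq_len, head_dim]
query_states = torch.randn(1, 32, 10, head_dim)

# Split each head into a rotated slice and a pass-through slice.
query_rot = query_states[..., :rotary_ndims]
query_pass = query_states[..., rotary_ndims:]
# ... apply the rotary cos/sin to query_rot only, then stitch the head back together ...
query_states = torch.cat((query_rot, query_pass), dim=-1)
```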
@@ -95,7 +104,7 @@ class PhiConfig(PretrainedConfig):
     >>> from transformers import PhiModel, PhiConfig

     >>> # Initializing a Phi-1 style configuration
-    >>> configuration = PhiConfig.from_pretrained("susnato/phi-1_dev")
+    >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")

     >>> # Initializing a model from the configuration
     >>> model = PhiModel(configuration)
@@ -114,6 +123,7 @@ class PhiConfig(PretrainedConfig):
         intermediate_size=8192,
         num_hidden_layers=24,
         num_attention_heads=32,
+        num_key_value_heads=None,
         resid_pdrop=0.0,
         embd_pdrop=0.0,
         attention_dropout=0.0,
@@ -136,6 +146,11 @@ class PhiConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attention_dropout = attention_dropout
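With `num_key_value_heads` stored on the config, choosing between MHA, GQA, and MQA becomes purely a configuration decision. A minimal sketch, using made-up head counts rather than the values of any released Phi checkpoint:

```python
from transformers import PhiConfig

# num_key_value_heads == num_attention_heads -> plain multi-head attention (MHA).
mha_config = PhiConfig(num_attention_heads=32, num_key_value_heads=32)

# Fewer KV heads than query heads -> grouped-query attention (GQA);
# here every 4 query heads share one key/value head.
gqa_config = PhiConfig(num_attention_heads=32, num_key_value_heads=8)

# A single KV head -> multi-query attention (MQA).
mqa_config = PhiConfig(num_attention_heads=32, num_key_value_heads=1)

# Leaving it unset falls back to num_attention_heads, i.e. MHA.
default_config = PhiConfig(num_attention_heads=32)
assert default_config.num_key_value_heads == 32
```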

src/transformers/models/phi/modeling_phi.py

@@ -54,12 +54,13 @@ if is_flash_attn_2_available():

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = "susnato/phi-1_dev"
+_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
 _CONFIG_FOR_DOC = "PhiConfig"

 PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "susnato/phi-1_dev",
-    "susnato/phi-1_5_dev",
+    "microsoft/phi-1",
+    "microsoft/phi-1_5",
+    "microsoft/phi-2",
     # See all Phi models at https://huggingface.co/models?filter=phi
 ]
@@ -214,7 +215,19 @@ class PhiMLP(nn.Module):
         return hidden_states


-# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention with Persimmon->Phi,persimmon->phi
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class PhiAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -229,9 +242,12 @@ class PhiAttention(nn.Module):
                 "when creating this class."
             )

+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.partial_rotary_factor = config.partial_rotary_factor
@@ -242,10 +258,13 @@ class PhiAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
-        self.qk_layernorm = config.qk_layernorm
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+
+        self.qk_layernorm = config.qk_layernorm
         if self.qk_layernorm:
             self.q_layernorm = nn.LayerNorm(
                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
@@ -253,7 +272,7 @@ class PhiAttention(nn.Module):
             self.k_layernorm = nn.LayerNorm(
                 config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
             )
-        self.attention_dropout = nn.Dropout(config.attention_dropout)
+
         self._init_rope()

     def _init_rope(self):
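Splitting the fused `query_key_value` projection into separate `q_proj`/`k_proj`/`v_proj` layers is what makes GQA possible: queries keep one head per attention head, while keys and values only produce `num_key_value_heads` heads. A shape-only sketch of the idea, with sizes picked for illustration rather than taken from a released checkpoint:

```python
import torch
from torch import nn

# Illustrative sizes, not a released Phi configuration.
hidden_size, num_heads, num_key_value_heads = 2048, 32, 8
head_dim = hidden_size // num_heads  # 64

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=True)
v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=True)

hidden_states = torch.randn(1, 10, hidden_size)  # [batch, seq_len, hidden_size]
assert q_proj(hidden_states).shape == (1, 10, num_heads * head_dim)            # 32 query heads
assert k_proj(hidden_states).shape == (1, 10, num_key_value_heads * head_dim)  # only 8 KV heads
```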
@@ -283,23 +302,6 @@ class PhiAttention(nn.Module):
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
-    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
-        storage as `fused_qkv`
-
-        Args:
-            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
-
-        Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
-            value: [batch_size, seq_length, num_heads, head_dim]
-        """
-        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -311,20 +313,17 @@ class PhiAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

-        # [batch_size, seq_length, 3 x hidden_size]
-        fused_qkv = self.query_key_value(hidden_states)
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)

-        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
-        query_states = query_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -354,11 +353,16 @@ class PhiAttention(nn.Module):
         key_states = torch.cat((key_rot, key_pass), dim=-1)

         if past_key_value is not None:
             # Specific to RoPE models with partial rotation
             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+        attn_weights = torch.matmul(
+            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
+        ) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -374,8 +378,8 @@ class PhiAttention(nn.Module):
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
-        attn_weights = self.attention_dropout(attn_weights)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

         attn_output = torch.matmul(attn_weights, value_states)
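The overflow fix called out in the commit list ("Fixes attention overflow on PhiAttention (for Phi-2)") boils down to running the QK^T matmul and the softmax in float32 even when the weights are half precision, then casting back before multiplying by the values. A self-contained sketch of that pattern, with arbitrary shapes:

```python
import math
import torch

# Low-precision dtype standing in for the model weights; bfloat16 also runs on CPU.
dtype = torch.bfloat16

bsz, num_heads, q_len, head_dim = 1, 32, 128, 80
query_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=dtype)
key_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=dtype)
value_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=dtype)

# Run QK^T in float32 so large logits cannot overflow the half-precision range ...
attn_weights = torch.matmul(
    query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(head_dim)

# ... keep the softmax in float32 too, then cast back to the value dtype for the final matmul.
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
assert attn_output.dtype == dtype
```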
@@ -398,9 +402,9 @@ class PhiAttention(nn.Module):

 class PhiFlashAttention2(PhiAttention):
     """
-    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays untouched.
-    The only required change would be on the forward pass where it needs to correctly call the public API of flash
-    attention and deal with padding tokens in case the input contains any of them.
+    Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
     """

     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
@@ -415,11 +419,12 @@ class PhiFlashAttention2(PhiAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # PhiFlashAttention2 attention does not support output_attentions
@@ -427,20 +432,20 @@ class PhiFlashAttention2(PhiAttention):
         bsz, q_len, _ = hidden_states.size()

-        # [batch_size, seq_length, 3 x hidden_size]
-        fused_qkv = self.query_key_value(hidden_states)
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_states, key_states, value_states) = self._split_heads(fused_qkv)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
             query_states = self.q_layernorm(query_states)
             key_states = self.k_layernorm(key_states)

-        # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
-        query_states = query_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
@@ -467,15 +472,13 @@ class PhiFlashAttention2(PhiAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        tgt_len = key_states.shape[2]
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.transpose(1, 2).view(bsz, q_len, self.num_heads, self.head_dim)
-        key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
-        value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
-
-        attn_dropout = self.config.attention_dropout if self.training else 0.0
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_dropout = self.attention_dropout if self.training else 0.0

         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
@@ -506,7 +509,7 @@ class PhiFlashAttention2(PhiAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.dense(attn_output)

         if not output_attentions:
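Once `_supports_flash_attn_2` is enabled (see the `PhiPreTrainedModel` hunk below), the FA2 path can be requested at load time exactly as the updated tests do. A hedged usage sketch, assuming a CUDA device and that the `flash-attn` package is installed; the prompt is purely illustrative:

```python
import torch
from transformers import AutoTokenizer, PhiForCausalLM

model = PhiForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # routes attention through PhiFlashAttention2
    device_map={"": 0},
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```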
@@ -708,6 +711,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     config_class = PhiConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["PhiDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
@@ -745,7 +749,7 @@ PHI_INPUTS_DOCSTRING = r"""
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.

-            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).

             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -852,13 +856,13 @@ class PhiModel(PhiPreTrainedModel):

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape
+            batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError("You have to specify either input_ids or inputs_embeds")

         past_key_values_length = 0
@@ -1020,8 +1024,8 @@ class PhiForCausalLM(PhiPreTrainedModel):
         ```python
         >>> from transformers import AutoTokenizer, PhiForCausalLM

-        >>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev")
-        >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
+        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

         >>> prompt = "This is an example script ."
         >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1029,7 +1033,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        'This is an example script .py file that uses the `os` module to create a new directory and write some text to it.\n\n``'
+        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

tests/models/phi/test_modeling_phi.py

@@ -365,18 +365,18 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     @require_bitsandbytes
     @pytest.mark.flash_attn_test
     @slow
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
     def test_flash_attn_2_generate_padding_right(self):
         """
         Overwritting the common test as the test is flaky on tiny models
         """
         model = PhiForCausalLM.from_pretrained(
-            "susnato/phi-1_5_dev",
+            "microsoft/phi-1",
             load_in_4bit=True,
             device_map={"": 0},
         )
-        tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

         texts = ["hi", "Hello this is a very long sentence"]
@@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         output_native = tokenizer.batch_decode(output_native)

         model = PhiForCausalLM.from_pretrained(
-            "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
+            "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
         )

         output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -408,7 +408,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-1").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -424,7 +424,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -440,7 +440,7 @@ class PhiIntegrationTest(unittest.TestCase):
             )
         }

-        model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device)
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-2").to(torch_device)
         model.eval()

         output = model(**input_ids).logits
@@ -450,8 +450,8 @@ class PhiIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3))

     def test_phi_2_generation(self):
-        model = PhiForCausalLM.from_pretrained("susnato/phi-2")
-        tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2")
+        model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

         inputs = tokenizer(
             "Can you help me write a formal email to a potential business partner proposing a joint venture?",