Compare commits

...

20 Commits
main ... mymain

Author SHA1 Message Date
ydshieh 6f25ef40fd fix 2024-04-24 16:50:49 +02:00
ydshieh d4a3ec26c4 run 2024-04-24 16:44:02 +02:00
Gustavo de Rosa aeb6ae7ebe fix(phi3): Splits inv_freq calculation in two lines. 2024-04-24 06:29:16 -07:00
Gustavo de Rosa 2abcd4dec3 fix(phi3): Adds last suggestions to modeling file. 2024-04-24 06:24:17 -07:00
Gustavo de Rosa 06cd06d29f Merge remote-tracking branch 'upstream/main' into main 2024-04-24 06:08:32 -07:00
Gustavo de Rosa 9427419dd9 fix(phi3): Fixes inv_freq not being re-computed for extended RoPE. 2024-04-23 17:40:33 -07:00
Gustavo de Rosa 817fec7bfa fix(phi3): Improves how rotary embedding classes are defined. 2024-04-23 14:33:20 -07:00
Gustavo de Rosa 4cfa767de7 fix(phi3): Uses gemma rotary embedding to support torch.compile. 2024-04-23 13:46:29 -07:00
Gustavo de Rosa 3a24a1d4d2 fix(phi3): Uses up_states instead of y in Phi3MLP. 2024-04-23 13:40:45 -07:00
Gustavo de Rosa d5aed89bd3 fix(phi3): Improves according first batch of reviews. 2024-04-23 13:37:51 -07:00
Gustavo de Rosa c442d06484 fix(phi3): Adds support for Su and Yarn embeddings. 2024-04-23 10:40:09 -07:00
Gustavo de Rosa 92d83790af fix(phi3): Fixes docstring typos. 2024-04-23 10:03:32 -07:00
Gustavo de Rosa 9bc1f1f137 fix(phi3): Fixes incorrect docstrings. 2024-04-23 09:12:07 -07:00
Gustavo de Rosa 56e6464f1a fix(phi3): Removes additional flash-attention usage, .e.g, swiglu and rmsnorm. 2024-04-23 08:39:27 -07:00
Gustavo de Rosa 508ec8ef31 chore(tests): Adds integration tests for Phi-3. 2024-04-23 08:21:03 -07:00
Gustavo de Rosa b62e6f3eed fix(tests): Fixes style of phi-3 test file. 2024-04-23 06:38:52 -07:00
Gustavo de Rosa 912edf15e4 fix(phi3): Fixes unit tests. 2024-04-23 06:26:31 -07:00
Gustavo de Rosa e0b68151ad fix(root): Ensures files are consistent. 2024-04-23 05:44:53 -07:00
Gustavo de Rosa 416eaa4160 fix(root): Fixes Phi-3 missing on readme. 2024-04-23 05:10:33 -07:00
Gustavo de Rosa c1e38b0bcf chore(root): Initial commit of Phi-3 files. 2024-04-23 05:08:54 -07:00
16 changed files with 2515 additions and 2 deletions

View File

@@ -460,6 +460,8 @@
title: Persimmon
- local: model_doc/phi
title: Phi
- local: model_doc/phi3
title: Phi-3
- local: model_doc/phobert
title: PhoBERT
- local: model_doc/plbart

View File

@@ -236,6 +236,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Perceiver](model_doc/perceiver) | ✅ | ❌ | ❌ |
| [Persimmon](model_doc/persimmon) | ✅ | ❌ | ❌ |
| [Phi](model_doc/phi) | ✅ | ❌ | ❌ |
| [Phi3](model_doc/phi3) | ✅ | ❌ | ❌ |
| [PhoBERT](model_doc/phobert) | ✅ | ✅ | ✅ |
| [Pix2Struct](model_doc/pix2struct) | ✅ | ❌ | ❌ |
| [PLBart](model_doc/plbart) | ✅ | ❌ | ❌ |

View File

@@ -0,0 +1,92 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Phi-3
## Overview
The Phi-3 model was proposed in [Phi-3 Technical Report: A Highly Capable Language Model Locally on Your Phone](https://arxiv.org/abs/2404.14219) by Microsoft.
### Summary
The abstract from the Phi-3 paper is the following:
We introduce phi-3-mini, a 3.8 billion parameter language model trained on 3.3 trillion tokens, whose overall performance, as measured by both academic benchmarks and internal testing, rivals that of models such as Mixtral 8x7B and GPT-3.5 (e.g., phi-3-mini achieves 69% on MMLU and 8.38 on MT-bench), despite being small enough to be deployed on a phone. The innovation lies entirely in our dataset for training, a scaled-up version of the one used for phi-2, composed of heavily filtered web data and synthetic data. The model is also further aligned for robustness, safety, and chat format. We also provide some initial parameter-scaling results with a 7B and 14B models trained for 4.8T tokens, called phi-3-small and phi-3-medium, both significantly more capable than phi-3-mini (e.g., respectively 75% and 78% on MMLU, and 8.7 and 8.9 on MT-bench).
The original code for Phi-3 can be found [here](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
## Usage tips
- This model is very similar to `Llama`; the main differences are [`Phi3SuScaledRotaryEmbedding`] and [`Phi3YarnScaledRotaryEmbedding`], which are used to extend the context of the rotary embeddings. The query, key and value projections are fused, and the MLP's up and gate projection layers are also fused (see the sketch below).
- The tokenizer used for this model is identical to the [`LlamaTokenizer`], with the exception of additional tokens.
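
The fused layouts can be seen directly in the tensor shapes. The sketch below splits a fused QKV projection back into query, key and value states, mirroring how a fused `qkv_proj` layer is typically unpacked; all sizes here are made up for illustration and are not the real checkpoint dimensions.

```python
>>> import torch

>>> # Hypothetical sizes, chosen only for this example
>>> hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 4, 16

>>> hidden_states = torch.randn(1, 8, hidden_size)
>>> # One matmul produces Q, K and V concatenated along the last dimension
>>> qkv_proj = torch.nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
>>> qkv = qkv_proj(hidden_states)

>>> # Slice the fused output back into the three states
>>> query_pos = num_heads * head_dim
>>> query_states = qkv[..., :query_pos]
>>> key_states = qkv[..., query_pos : query_pos + num_kv_heads * head_dim]
>>> value_states = qkv[..., query_pos + num_kv_heads * head_dim :]
>>> query_states.shape, key_states.shape, value_states.shape
(torch.Size([1, 8, 64]), torch.Size([1, 8, 64]), torch.Size([1, 8, 64]))
```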
## How to use Phi-3
<Tip warning={true}>
Phi-3 has been integrated into the development version (4.40.0.dev) of `transformers`. Until the official version is released through `pip`, ensure that you are doing one of the following:
* When loading the model, ensure that `trust_remote_code=True` is passed as an argument of the `from_pretrained()` function.
* Update your local `transformers` to the development version: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`. The previous command is an alternative to cloning the repository and installing from source.
</Tip>
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> messages = [{"role": "system", "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user."},{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}]
>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
>>> outputs = model.generate(inputs, max_new_tokens=32)
>>> text = tokenizer.batch_decode(outputs)[0]
>>> print(text)
<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit
```
## Phi3Config
[[autodoc]] Phi3Config
<frameworkcontent>
<pt>
## Phi3Model
[[autodoc]] Phi3Model
- forward
## Phi3ForCausalLM
[[autodoc]] Phi3ForCausalLM
- forward
- generate
## Phi3ForSequenceClassification
[[autodoc]] Phi3ForSequenceClassification
- forward
## Phi3ForTokenClassification
[[autodoc]] Phi3ForTokenClassification
- forward
</pt>
</frameworkcontent>

View File

@@ -65,6 +65,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)

View File

@@ -709,6 +709,7 @@ _import_structure = {
],
"models.persimmon": ["PERSIMMON_PRETRAINED_CONFIG_ARCHIVE_MAP", "PersimmonConfig"],
"models.phi": ["PHI_PRETRAINED_CONFIG_ARCHIVE_MAP", "PhiConfig"],
"models.phi3": ["PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP", "Phi3Config"],
"models.phobert": ["PhobertTokenizer"],
"models.pix2struct": [
"PIX2STRUCT_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -3057,6 +3058,16 @@ else:
"PhiPreTrainedModel",
]
)
_import_structure["models.phi3"].extend(
[
"PHI3_PRETRAINED_MODEL_ARCHIVE_LIST",
"Phi3ForCausalLM",
"Phi3ForSequenceClassification",
"Phi3ForTokenClassification",
"Phi3Model",
"Phi3PreTrainedModel",
]
)
_import_structure["models.pix2struct"].extend(
[
"PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -5669,6 +5680,7 @@ if TYPE_CHECKING:
PersimmonConfig,
)
from .models.phi import PHI_PRETRAINED_CONFIG_ARCHIVE_MAP, PhiConfig
from .models.phi3 import PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP, Phi3Config
from .models.phobert import PhobertTokenizer
from .models.pix2struct import (
PIX2STRUCT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -7715,6 +7727,14 @@ if TYPE_CHECKING:
PhiModel,
PhiPreTrainedModel,
)
from .models.phi3 import (
PHI3_PRETRAINED_MODEL_ARCHIVE_LIST,
Phi3ForCausalLM,
Phi3ForSequenceClassification,
Phi3ForTokenClassification,
Phi3Model,
Phi3PreTrainedModel,
)
from .models.pix2struct import (
PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST,
Pix2StructForConditionalGeneration,

View File

@@ -179,6 +179,7 @@ from . import (
perceiver,
persimmon,
phi,
phi3,
phobert,
pix2struct,
plbart,

View File

@@ -191,6 +191,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("perceiver", "PerceiverConfig"),
("persimmon", "PersimmonConfig"),
("phi", "PhiConfig"),
("phi3", "Phi3Config"),
("pix2struct", "Pix2StructConfig"),
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
@@ -471,6 +472,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("perceiver", "Perceiver"),
("persimmon", "Persimmon"),
("phi", "Phi"),
("phi3", "Phi3"),
("phobert", "PhoBERT"),
("pix2struct", "Pix2Struct"),
("plbart", "PLBart"),

View File

@@ -180,6 +180,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("perceiver", "PerceiverModel"),
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
("plbart", "PLBartModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
@@ -474,6 +475,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("pegasus", "PegasusForCausalLM"),
("persimmon", "PersimmonForCausalLM"),
("phi", "PhiForCausalLM"),
("phi3", "Phi3ForCausalLM"),
("plbart", "PLBartForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
@@ -884,6 +886,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("perceiver", "PerceiverForSequenceClassification"),
("persimmon", "PersimmonForSequenceClassification"),
("phi", "PhiForSequenceClassification"),
("phi3", "Phi3ForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("qwen2", "Qwen2ForSequenceClassification"),
@@ -1049,6 +1052,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("nezha", "NezhaForTokenClassification"),
("nystromformer", "NystromformerForTokenClassification"),
("phi", "PhiForTokenClassification"),
("phi3", "Phi3ForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("rembert", "RemBertForTokenClassification"),
("roberta", "RobertaForTokenClassification"),

View File

@@ -353,6 +353,7 @@ else:
),
),
("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),

View File

@@ -0,0 +1,69 @@
# Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_phi3": ["PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP", "Phi3Config"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_phi3"] = [
"PHI3_PRETRAINED_MODEL_ARCHIVE_LIST",
"Phi3PreTrainedModel",
"Phi3Model",
"Phi3ForCausalLM",
"Phi3ForSequenceClassification",
"Phi3ForTokenClassification",
]
if TYPE_CHECKING:
from .configuration_phi3 import PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP, Phi3Config
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_phi3 import (
PHI3_PRETRAINED_MODEL_ARCHIVE_LIST,
Phi3ForCausalLM,
Phi3ForSequenceClassification,
Phi3ForTokenClassification,
Phi3Model,
Phi3PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
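
As context for the lazy-import boilerplate above (this note is not part of the diff): `_LazyModule` defers the actual import of `modeling_phi3` until one of its symbols is first accessed, so `import transformers` stays fast and works even when torch is absent. A minimal sketch of the effect:

```python
# Importing the config is cheap: the lazy module only resolves
# "configuration_phi3" on first access, and never touches torch.
from transformers import Phi3Config

config = Phi3Config(num_hidden_layers=2)

# Accessing a model class is what resolves "modeling_phi3" (and imports
# torch). If torch is unavailable, a dummy placeholder is returned instead
# and raises only when instantiated.
from transformers import Phi3Model

model = Phi3Model(config)
```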

View File

@@ -0,0 +1,213 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Phi-3 model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
"microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
}
class Phi3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the
[microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32064):
Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Phi3Model`].
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
resid_pdrop (`float`, *optional*, defaults to 0.0):
Dropout probability for the MLP outputs.
embd_pdrop (`float`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after computing the attention scores.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
original_max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model was trained with. This is used to determine the size of the
original RoPE embeddings when using long scaling.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value used for the RMSNorm.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the input and output word embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 32000):
The id of the "end-of-sequence" token.
pad_token_id (`int`, *optional*, defaults to 32000):
The id of the padding token.
sliding_window (`int`, *optional*):
Sliding window attention window size. If `None`, no sliding window is applied.
Example:
```python
>>> from transformers import Phi3Model, Phi3Config
>>> # Initializing a Phi-3 style configuration
>>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> # Initializing a model from the configuration
>>> model = Phi3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "phi3"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32064,
hidden_size=3072,
intermediate_size=8192,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act="silu",
max_position_embeddings=4096,
original_max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
bos_token_id=1,
eos_token_id=32000,
pad_token_id=32000,
sliding_window=None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.sliding_window = sliding_window
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
raise ValueError(
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
)
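
To make the validation above concrete, here is a minimal sketch of a `rope_scaling` value that passes `_rope_scaling_validation` for the default geometry (`hidden_size=3072`, `num_attention_heads=32`, hence `3072 // 32 // 2 = 48` factors per list). The factor values are placeholders, not the ones shipped with any checkpoint.

```python
from transformers import Phi3Config

n_factors = 3072 // 32 // 2  # hidden_size // num_attention_heads // 2 = 48

config = Phi3Config(
    max_position_embeddings=131072,          # extended context window
    original_max_position_embeddings=4096,   # context the model was trained with
    rope_scaling={
        "type": "su",                        # must be "su" or "yarn"
        "short_factor": [1.0] * n_factors,   # placeholder values
        "long_factor": [2.0] * n_factors,    # placeholder values
    },
)
```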

File diff suppressed because it is too large.

View File

@@ -6752,6 +6752,44 @@ class PhiPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Phi3ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi3ForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi3ForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi3Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi3PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

View File

@@ -0,0 +1,474 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Phi-3 model. """
import unittest
from parameterized import parameterized
from transformers import Phi3Config, is_torch_available, set_seed
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
AutoTokenizer,
Phi3ForCausalLM,
Phi3ForSequenceClassification,
Phi3ForTokenClassification,
Phi3Model,
)
class Phi3ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return Phi3Config(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Phi3
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = Phi3Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Phi3
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = Phi3Model(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Phi3
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = Phi3ForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Phi3
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = Phi3ForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append the new tokens to input_ids and attention_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(Phi3Model, Phi3ForCausalLM, Phi3ForSequenceClassification, Phi3ForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (Phi3ForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": Phi3Model,
"text-classification": Phi3ForSequenceClassification,
"text-generation": Phi3ForCausalLM,
"token-classification": Phi3ForTokenClassification,
"zero-shot": Phi3ForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi3
def setUp(self):
self.model_tester = Phi3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Phi3Config, hidden_size=37)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi3,llama->phi3
def test_phi3_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = Phi3ForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi3,llama->phi3
def test_phi3_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "single_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = Phi3ForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi3,llama->phi3
def test_phi3_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "multi_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = Phi3ForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@parameterized.expand([("su",), ("yarn",)])
def test_model_rope_scaling_from_config(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = Phi3Model(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
n_factors = config.hidden_size // config.num_attention_heads // 2
config.rope_scaling = {
"type": scaling_type,
"short_factor": [5.0 for _ in range(n_factors)],
"long_factor": [5.0 for _ in range(n_factors)],
}
scaled_model = Phi3Model(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Scaling changes the RoPE embeddings, both for the short and long outputs
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@slow
@require_torch
class Phi3IntegrationTest(unittest.TestCase):
def test_model_phi3_mini_4k_instruct_logits(self):
input_ids = {
"input_ids": torch.tensor(
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
)
}
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct").to(torch_device)
model.eval()
output = model(**input_ids).logits
EXPECTED_OUTPUT = torch.tensor([[ 0.9979, -1.9449, -2.5613, -2.2110, -0.9323, -2.2726, -3.2468, -2.0122,-1.0021, -1.2764, -1.0876, -1.2358, 3.9385, 6.2152, -0.3695, -2.3285,-1.2907, -1.8238, -1.9941, -2.2098, -0.6923, -1.6793, -1.1660, -2.0469,-0.7369, -1.4101, -1.4091, -3.1694, -1.8383, -1.1952],[ 3.0525, 1.9178, 3.7016, 0.9263, 0.3397, 1.9584, 2.1347, 0.3482, 1.3773, 0.2153, 0.2798, 0.8360, 9.0936, 11.4944, -0.3575, -0.9442,-0.1246, 1.3869, 0.9846, 1.7243, 0.9150, 1.0823, 0.4313, 1.5742, 0.2566, -0.1401, -1.3019, 0.4967, 0.6941, 0.7214]]).to(torch_device) # fmt: skip
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4))
def test_phi3_mini_4k_instruct_generation(self):
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
messages = [
{
"role": "system",
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
},
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=32)
output_text = tokenizer.batch_decode(outputs)
EXPECTED_OUTPUT = [
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Absolutely! Bananas and dragonfruits are both delicious fruits that can be combined in various ways to create tasty and nutrit"
]
self.assertListEqual(output_text, EXPECTED_OUTPUT)
def test_model_phi3_mini_128k_instruct_logits(self):
input_ids = {
"input_ids": torch.tensor(
[[1212, 318, 281, 1672, 2643, 290, 428, 318, 257, 1332]], dtype=torch.long, device=torch_device
)
}
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-128k-instruct").to(torch_device)
model.eval()
output = model(**input_ids).logits
EXPECTED_OUTPUT = torch.tensor([[ 1.8478, -0.5709, -1.6792, -1.2133, -0.7809, -0.8817, -2.0969, -1.1191,-0.7731, -1.0483, -0.5961, -1.3067, 3.1325, 6.9442, -0.4803, -0.9154,-1.3085, -1.0822, -1.1433, -0.7660, -0.8531, -0.9150, -0.6179, -1.6153,-0.2239, -1.3207, -1.1187, -2.4795, -1.4733, -0.4931],[ 3.5839, 2.4722, 3.7130, 1.2032, 0.7356, 2.7777, 2.5256, 0.9157, 1.6431, 0.3533, 0.5100, 1.3512, 8.9873, 10.9815, 0.3530, 0.1473, 0.2051, 1.8553, 1.5988, 2.2268, 1.1897, 1.2829, 0.7894, 1.8895, 0.7666, 0.4122, -0.9316, 0.9936, 1.2722, 0.8263]]).to(torch_device) # fmt: skip
self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-4, rtol=1e-4))
def test_phi3_mini_128k_instruct_generation(self):
model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-128k-instruct")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-128k-instruct")
messages = [
{
"role": "system",
"content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.",
},
{"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=32)
output_text = tokenizer.batch_decode(outputs)
EXPECTED_OUTPUT = [
"<s><|system|> You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.<|end|><|user|> Can you provide ways to eat combinations of bananas and dragonfruits?<|end|><|assistant|> Certainly! Bananas and dragonfruits can be combined in various delicious and healthy ways. Here are some ideas:\n\n1."
]
self.assertListEqual(output_text, EXPECTED_OUTPUT)

View File

@@ -91,6 +91,6 @@ if __name__ == "__main__":
         find_new_model = reg.findall(x)
         if len(find_new_model) > 0:
             new_model = find_new_model[0]
-        # It's unlikely we have 2 new modeling files in a pull request.
-        break
+            # It's unlikely we have 2 new modeling files in a pull request.
+            break
print(new_model)