Add JetMoE model (#30005)
* init jetmoe code * update archive maps * remove flax import * fix import error * update README * ruff fix * update readme * fix * update config * fix issue * merge files * fix model bug * fix test * auto fix * model size * add comments * fix form * add flash attention support * fix attention head number * fix init * fix support list * sort auto mapping * fix test * fix docs * update test * fix test * fix test * change variable name * fix config * fix init * update format * clean code * fix config * fix config * change default config * update config * fix issues * update formate * update config argument * update format * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * change to mixtral aux loss * change to cache_position * debug * fix bugs * debug * fix format * fix format * fix copy * fix format * fix format * fix sort * fix sort * fix sort * add copy comment * add copy from * remove debug code * revert readme update * add copy * debug * remove debug code * fix flash attention * add comments * clean code * clean format * fix format * fix format * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * change variable name * add copied from * fix variable name * remove deprecated functinos * sync to llama implementation * fix format * fix copy * fix format * update format * remove repr * add comment for moe weight * fix copy * Update src/transformers/models/jetmoe/configuration_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update 
src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add comments and reformat config * fix format * fix format * fix format * update test * update doc string in config * Update src/transformers/models/jetmoe/modeling_jetmoe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update config doc * update attention cache * fix format * fix copy --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
parent d84f34ad77
commit ccdabc5642
@@ -386,6 +386,8 @@
       title: I-BERT
     - local: model_doc/jamba
       title: Jamba
+    - local: model_doc/jetmoe
+      title: JetMoe
     - local: model_doc/jukebox
       title: Jukebox
     - local: model_doc/led
@@ -166,6 +166,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
 | [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
 | [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
+| [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ |
 | [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
 | [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ |
 | [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ |
@@ -0,0 +1,49 @@
<!--Copyright 2024 JetMoe team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not
be rendered properly in your Markdown viewer.

-->

# JetMoe

## Overview

**JetMoe-8B** is an 8B-parameter Mixture-of-Experts (MoE) language model developed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ) and [MyShell](https://myshell.ai/).
The JetMoe project aims to provide LLaMA2-level performance from an efficient language model trained on a limited budget.
To achieve this goal, JetMoe uses a sparsely activated architecture inspired by [ModuleFormer](https://arxiv.org/abs/2306.04640).
Each JetMoe block consists of two MoE layers: Mixture of Attention Heads and Mixture of MLP Experts.
Given the input tokens, it activates a subset of its experts to process them.
This sparse activation scheme enables JetMoe to achieve much better training throughput than dense models of similar size.
The training throughput of JetMoe-8B is around 100B tokens per day on a cluster of 96 H100 GPUs with a straightforward 3-way pipeline parallelism strategy.

This model was contributed by [Yikang Shen](https://huggingface.co/YikangS).
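The snippet below is a minimal generation sketch; it assumes the [`jetmoe/jetmoe-8b`](https://huggingface.co/jetmoe/jetmoe-8b) checkpoint and mirrors the greedy-decoding setup used in this PR's integration tests:

```python
from transformers import AutoTokenizer, JetMoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")

# Greedy decoding, as in the slow integration tests
inputs = tokenizer("My favourite condiment is ", return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```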
## JetMoeConfig

[[autodoc]] JetMoeConfig

## JetMoeModel

[[autodoc]] JetMoeModel
    - forward

## JetMoeForCausalLM

[[autodoc]] JetMoeForCausalLM
    - forward

## JetMoeForSequenceClassification

[[autodoc]] JetMoeForSequenceClassification
    - forward
@@ -50,6 +50,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
 * [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
 * [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
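JetMoe can then be loaded with FlashAttention-2 enabled. A minimal sketch, assuming the `jetmoe/jetmoe-8b` checkpoint, a supported GPU, and the `flash-attn` package:

```python
import torch
from transformers import JetMoeForCausalLM

# Half precision is required for FlashAttention-2
model = JetMoeForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
```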
@@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
 * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
+* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
 * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
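SDPA can likewise be requested explicitly for JetMoe. A sketch under the same checkpoint assumption (SDPA needs only a recent PyTorch, no extra package):

```python
import torch
from transformers import JetMoeForCausalLM

# SDPA is typically the default, but it can be pinned explicitly
model = JetMoeForCausalLM.from_pretrained(
    "jetmoe/jetmoe-8b",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
    device_map="auto",
)
```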
@@ -448,6 +448,7 @@ _import_structure = {
         "InstructBlipVisionConfig",
     ],
     "models.jamba": ["JambaConfig"],
+    "models.jetmoe": ["JetMoeConfig"],
     "models.jukebox": [
         "JukeboxConfig",
         "JukeboxPriorConfig",
@@ -2202,6 +2203,14 @@ else:
             "JambaPreTrainedModel",
         ]
     )
+    _import_structure["models.jetmoe"].extend(
+        [
+            "JetMoeForCausalLM",
+            "JetMoeForSequenceClassification",
+            "JetMoeModel",
+            "JetMoePreTrainedModel",
+        ]
+    )
     _import_structure["models.jukebox"].extend(
         [
             "JukeboxModel",
@@ -4973,6 +4982,7 @@ if TYPE_CHECKING:
         InstructBlipVisionConfig,
     )
     from .models.jamba import JambaConfig
+    from .models.jetmoe import JetMoeConfig
     from .models.jukebox import (
         JukeboxConfig,
         JukeboxPriorConfig,
@@ -6591,6 +6601,12 @@ if TYPE_CHECKING:
             JambaModel,
             JambaPreTrainedModel,
         )
+        from .models.jetmoe import (
+            JetMoeForCausalLM,
+            JetMoeForSequenceClassification,
+            JetMoeModel,
+            JetMoePreTrainedModel,
+        )
         from .models.jukebox import (
             JukeboxModel,
             JukeboxPreTrainedModel,
@@ -117,6 +117,7 @@ from . import (
     informer,
     instructblip,
     jamba,
+    jetmoe,
     jukebox,
     kosmos2,
     layoutlm,
@@ -128,6 +128,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("informer", "InformerConfig"),
         ("instructblip", "InstructBlipConfig"),
         ("jamba", "JambaConfig"),
+        ("jetmoe", "JetMoeConfig"),
         ("jukebox", "JukeboxConfig"),
         ("kosmos-2", "Kosmos2Config"),
         ("layoutlm", "LayoutLMConfig"),
@@ -399,6 +400,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("informer", "Informer"),
         ("instructblip", "InstructBLIP"),
         ("jamba", "Jamba"),
+        ("jetmoe", "JetMoe"),
         ("jukebox", "Jukebox"),
         ("kosmos-2", "KOSMOS-2"),
         ("layoutlm", "LayoutLM"),
@@ -125,6 +125,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("imagegpt", "ImageGPTModel"),
         ("informer", "InformerModel"),
         ("jamba", "JambaModel"),
+        ("jetmoe", "JetMoeModel"),
         ("jukebox", "JukeboxModel"),
         ("kosmos-2", "Kosmos2Model"),
         ("layoutlm", "LayoutLMModel"),
@@ -458,6 +459,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
         ("gptj", "GPTJForCausalLM"),
         ("jamba", "JambaForCausalLM"),
+        ("jetmoe", "JetMoeForCausalLM"),
         ("llama", "LlamaForCausalLM"),
         ("mamba", "MambaForCausalLM"),
         ("marian", "MarianForCausalLM"),
@@ -860,6 +862,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("gptj", "GPTJForSequenceClassification"),
         ("ibert", "IBertForSequenceClassification"),
         ("jamba", "JambaForSequenceClassification"),
+        ("jetmoe", "JetMoeForSequenceClassification"),
         ("layoutlm", "LayoutLMForSequenceClassification"),
         ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
         ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
@@ -211,6 +211,13 @@ else:
                 "LlamaTokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        (
+            "jetmoe",
+            (
+                "LlamaTokenizer" if is_sentencepiece_available() else None,
+                "LlamaTokenizerFast" if is_tokenizers_available() else None,
+            ),
+        ),
         ("jukebox", ("JukeboxTokenizer", None)),
         (
             "kosmos-2",
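As the mapping above shows, JetMoe has no dedicated tokenizer class and reuses Llama's, so the auto class should resolve accordingly (a sketch, assuming the `jetmoe/jetmoe-8b` checkpoint):

```python
from transformers import AutoTokenizer

# Falls back to the Llama tokenizer classes registered for "jetmoe"
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
print(type(tokenizer).__name__)  # LlamaTokenizerFast (or LlamaTokenizer with use_fast=False)
```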
@@ -0,0 +1,56 @@
# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
    "configuration_jetmoe": ["JetMoeConfig"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_jetmoe"] = [
        "JetMoeForCausalLM",
        "JetMoeModel",
        "JetMoePreTrainedModel",
        "JetMoeForSequenceClassification",
    ]

if TYPE_CHECKING:
    from .configuration_jetmoe import JetMoeConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_jetmoe import (
            JetMoeForCausalLM,
            JetMoeForSequenceClassification,
            JetMoeModel,
            JetMoePreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JetMoe model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class JetMoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`JetMoeModel`]. It is used to instantiate a
    JetMoe model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the JetMoe-4B.

    [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the JetMoe model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`JetMoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key-value heads for each attention layer in the Transformer decoder.
        kv_channels (`int`, *optional*, defaults to 128):
            Defines the number of channels for the key and value tensors.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. JetMoe's attention allows sequences
            of up to 4096 tokens.
        activation_function (`str`, *optional*, defaults to `"silu"`):
            Defines the activation function for the MLP experts.
        num_local_experts (`int`, *optional*, defaults to 8):
            Defines the number of experts in the MoE and the MoA.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per token, for both the MoE and the MoA.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss.
        aux_loss_coef (`float`, *optional*, defaults to 0.01):
            The coefficient for the auxiliary loss.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.01):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import JetMoeModel, JetMoeConfig

    >>> # Initializing a JetMoe 4B style configuration
    >>> configuration = JetMoeConfig()

    >>> # Initializing a model from the JetMoe 4B style configuration
    >>> model = JetMoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "jetmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2048,
        num_hidden_layers=12,
        num_key_value_heads=16,
        kv_channels=128,
        intermediate_size=5632,
        max_position_embeddings=4096,
        activation_function="silu",
        num_local_experts=8,
        num_experts_per_tok=2,
        output_router_logits=False,
        aux_loss_coef=0.01,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rms_norm_eps=1e-6,
        initializer_range=0.01,
        attention_dropout=0.0,
        **kwargs,
    ):
        if num_experts_per_tok > num_local_experts:
            raise ValueError("`num_experts_per_tok` must be less than or equal to `num_local_experts`")
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
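        # Mixture of Attention heads (MoA): each token is routed to `num_experts_per_tok`
        # attention experts, each contributing `num_key_value_heads` heads, so the
        # effective number of attention heads is their product.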
        self.num_attention_heads = num_key_value_heads * num_experts_per_tok
        self.num_key_value_heads = num_key_value_heads
        self.kv_channels = kv_channels
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.activation_function = activation_function
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.output_router_logits = output_router_logits
        self.aux_loss_coef = aux_loss_coef
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps

        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )
File diff suppressed because it is too large
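The suppressed diff holds the modeling code. As a rough sketch of the top-k expert routing at the heart of such an MoE/MoA layer (illustrative only: the names, shapes, and routing details below are assumptions, not the actual `modeling_jetmoe.py` implementation):

```python
import torch
import torch.nn.functional as F


def top_k_routing(hidden_states: torch.Tensor, router_weight: torch.Tensor, top_k: int = 2):
    # hidden_states: (num_tokens, hidden_size); router_weight: (num_experts, hidden_size)
    router_logits = hidden_states @ router_weight.t()  # (num_tokens, num_experts)
    routing_weights = F.softmax(router_logits, dim=-1)
    topk_weights, topk_indices = routing_weights.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize over chosen experts
    return topk_indices, topk_weights, router_logits  # logits feed the auxiliary load-balancing loss


tokens = torch.randn(5, 2048)  # 5 tokens with hidden_size=2048
router = torch.randn(8, 2048)  # 8 local experts, as in the default config
indices, weights, _ = top_k_routing(tokens, router)
print(indices.shape, weights.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
```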
@@ -4334,6 +4334,34 @@ class JambaPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class JetMoeForCausalLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class JetMoeForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class JetMoeModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class JetMoePreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class JukeboxModel(metaclass=DummyObject):
     _backends = ["torch"]
 
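These dummy classes let `from transformers import JetMoeModel` succeed even when PyTorch is absent, deferring the failure to instantiation. A sketch of the assumed behavior in a torch-less environment:

```python
# Assumed environment: transformers installed, torch missing
from transformers import JetMoeModel

try:
    model = JetMoeModel()  # instantiation triggers requires_backends
except ImportError as err:
    print(err)  # message explains that JetMoeModel requires PyTorch
```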
@@ -0,0 +1,536 @@
# coding=utf-8
# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch JetMoe model."""

import gc
import tempfile
import unittest

import pytest
from parameterized import parameterized

from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
from transformers.testing_utils import (
    backend_empty_cache,
    is_flaky,
    require_flash_attn,
    require_torch,
    require_torch_gpu,
    require_torch_sdpa,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        JetMoeForCausalLM,
        JetMoeForSequenceClassification,
        JetMoeModel,
    )

class JetMoeModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_key_value_heads=2,
        kv_channels=8,
        intermediate_size=37,
        hidden_act="silu",
        num_local_experts=4,
        num_experts_per_tok=2,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.kv_channels = kv_channels
        self.num_attention_heads = num_key_value_heads * num_experts_per_tok
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.ones(self.batch_size, self.seq_length).to(torch_device)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return JetMoeConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_key_value_heads=self.num_key_value_heads,
            kv_channels=self.kv_channels,
            intermediate_size=self.intermediate_size,
            activation_function=self.hidden_act,
            num_local_experts=self.num_local_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = JetMoeModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = JetMoeModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = JetMoeForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = JetMoeForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (JetMoeModel, JetMoeForCausalLM, JetMoeForSequenceClassification) if is_torch_available() else ()
    )
    all_generative_model_classes = (JetMoeForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": JetMoeModel,
            "text-classification": JetMoeForSequenceClassification,
            "text-generation": JetMoeForCausalLM,
            "zero-shot": JetMoeForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    test_mismatched_shapes = False
    test_cpu_offload = False
    test_disk_offload_bin = False
    test_disk_offload_safetensors = False

    # TODO: @Fxmarty
    @is_flaky(max_attempts=3, description="flaky on some models.")
    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        super().test_eager_matches_sdpa_generate()

    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    def setUp(self):
        self.model_tester = JetMoeModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=JetMoeConfig, common_properties=["hidden_size", "num_hidden_layers"]
        )

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
    def test_config(self):
        self.config_tester.run_common_tests()

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("JetMoe buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("JetMoe uses MoA on all models, so the KV cache is in a non-standard format")
    def test_past_key_values_format(self):
        pass

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        import torch

        for model_class in self.all_generative_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                with self.assertRaises(ValueError):
                    _ = model.generate(
                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                    )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_use_cache(self):
        import torch

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # NOTE: JetMoe apparently does not support right padding + use_cache with FA2.
                dummy_attention_mask[:, -1] = 1

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                # Just test that a large cache works as expected
                _ = model.generate(
                    dummy_input,
                    attention_mask=dummy_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    use_cache=True,
                )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest("JetMoe flash attention does not support right padding")


@require_torch
class JetMoeIntegrationTest(unittest.TestCase):
    @slow
    def test_model_8b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            out = model(input_ids).logits.cpu()
        # Expected mean on dim = -1
        EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]])
        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
        # slicing logits[0, 0, 0:30]
        EXPECTED_SLICE = torch.tensor([-3.3689, 5.9006, 5.7450, -1.7012, -4.7072, -4.7071, -4.7071, -4.7071, -4.7072, -4.7072, -4.7072, -4.7071, 3.8321, 9.1746, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071])  # fmt: skip
        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)

        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @slow
    def test_model_8b_generation(self):
        EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=10, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @slow
    def test_model_8b_batched_generation(self):
        EXPECTED_TEXT_COMPLETION = [
            """My favourite condiment is ....\nI love ketchup. I love""",
            """My favourite 2018 Christmas present was a new pair""",
        ]
        prompt = [
            "My favourite condiment is ",
            "My favourite ",
        ]
        tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
        generated_ids = model.generate(**input_ids, max_new_tokens=10, temperature=0)
        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

        del model
        backend_empty_cache(torch_device)
        gc.collect()