diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index d654478970..cd60486e16 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -386,6 +386,8 @@
         title: I-BERT
       - local: model_doc/jamba
         title: Jamba
+      - local: model_doc/jetmoe
+        title: JetMoe
       - local: model_doc/jukebox
         title: Jukebox
       - local: model_doc/led
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 9adb669e2c..943f98c218 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -166,6 +166,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
 | [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
 | [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
+| [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ |
 | [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
 | [KOSMOS-2](model_doc/kosmos-2) | ✅ | ❌ | ❌ |
 | [LayoutLM](model_doc/layoutlm) | ✅ | ✅ | ❌ |
diff --git a/docs/source/en/model_doc/jetmoe.md b/docs/source/en/model_doc/jetmoe.md
new file mode 100644
index 0000000000..87f99c6f99
--- /dev/null
+++ b/docs/source/en/model_doc/jetmoe.md
@@ -0,0 +1,49 @@
+
+# JetMoe
+
+## Overview
+
+**JetMoe-8B** is an 8B Mixture-of-Experts (MoE) language model developed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ) and [MyShell](https://myshell.ai/).
+The JetMoe project aims to provide LLaMA2-level performance with an efficient language model trained on a limited budget.
+To achieve this goal, JetMoe uses a sparsely activated architecture inspired by [ModuleFormer](https://arxiv.org/abs/2306.04640).
+Each JetMoe block consists of two MoE layers: Mixture of Attention Heads and Mixture of MLP Experts.
+Given the input tokens, it activates a subset of its experts to process them.
+This sparse activation scheme enables JetMoe to achieve much better training throughput than dense models of similar size.
+The training throughput of JetMoe-8B is around 100B tokens per day on a cluster of 96 H100 GPUs with a straightforward 3-way pipeline parallelism strategy.
+
+This model was contributed by [Yikang Shen](https://huggingface.co/YikangS).
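+
+A minimal text-generation sketch is shown below. It assumes the [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b) checkpoint together with the standard `AutoTokenizer`/`AutoModelForCausalLM` classes; the dtype and device settings are illustrative, so adapt them to your hardware.
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
+>>> # device_map="auto" needs `accelerate`; drop it (and torch_dtype) for a plain CPU run
+>>> model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b", torch_dtype="auto", device_map="auto")
+
+>>> inputs = tokenizer("The JetMoe architecture is", return_tensors="pt").to(model.device)
+>>> generated = model.generate(**inputs, max_new_tokens=32, do_sample=False)
+>>> print(tokenizer.decode(generated[0], skip_special_tokens=True))
+```
+
+Because this release registers both SDPA and FlashAttention-2 kernels for JetMoe, `attn_implementation="flash_attention_2"` can also be passed to `from_pretrained` when the `flash-attn` package is installed.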
+ + +## JetMoeConfig + +[[autodoc]] JetMoeConfig + +## JetMoeModel + +[[autodoc]] JetMoeModel + - forward + +## JetMoeForCausalLM + +[[autodoc]] JetMoeForCausalLM + - forward + +## JetMoeForSequenceClassification + +[[autodoc]] JetMoeForSequenceClassification + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index c8e99c1d43..151ca765a2 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -50,6 +50,7 @@ FlashAttention-2 is currently supported for the following architectures: * [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel) * [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) +* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [Llava](https://huggingface.co/docs/transformers/model_doc/llava) @@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 502c561cc8..16970ac22a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -448,6 +448,7 @@ _import_structure = { "InstructBlipVisionConfig", ], "models.jamba": ["JambaConfig"], + "models.jetmoe": ["JetMoeConfig"], "models.jukebox": [ "JukeboxConfig", "JukeboxPriorConfig", @@ -2202,6 +2203,14 @@ else: "JambaPreTrainedModel", ] ) + _import_structure["models.jetmoe"].extend( + [ + "JetMoeForCausalLM", + "JetMoeForSequenceClassification", + "JetMoeModel", + "JetMoePreTrainedModel", + ] + ) _import_structure["models.jukebox"].extend( [ "JukeboxModel", @@ -4973,6 +4982,7 @@ if TYPE_CHECKING: InstructBlipVisionConfig, ) from .models.jamba import JambaConfig + from .models.jetmoe import JetMoeConfig from .models.jukebox import ( JukeboxConfig, JukeboxPriorConfig, @@ -6591,6 +6601,12 @@ if TYPE_CHECKING: JambaModel, JambaPreTrainedModel, ) + from .models.jetmoe import ( + JetMoeForCausalLM, + JetMoeForSequenceClassification, + JetMoeModel, + JetMoePreTrainedModel, + ) from .models.jukebox import ( JukeboxModel, JukeboxPreTrainedModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f07a4fc588..c0fd6cc6d2 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -117,6 +117,7 @@ from . 
import ( informer, instructblip, jamba, + jetmoe, jukebox, kosmos2, layoutlm, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index f5569eb1cb..634399ad64 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -128,6 +128,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("informer", "InformerConfig"), ("instructblip", "InstructBlipConfig"), ("jamba", "JambaConfig"), + ("jetmoe", "JetMoeConfig"), ("jukebox", "JukeboxConfig"), ("kosmos-2", "Kosmos2Config"), ("layoutlm", "LayoutLMConfig"), @@ -399,6 +400,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("informer", "Informer"), ("instructblip", "InstructBLIP"), ("jamba", "Jamba"), + ("jetmoe", "JetMoe"), ("jukebox", "Jukebox"), ("kosmos-2", "KOSMOS-2"), ("layoutlm", "LayoutLM"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f00c223d2e..cb3e84efc2 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -125,6 +125,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), ("jamba", "JambaModel"), + ("jetmoe", "JetMoeModel"), ("jukebox", "JukeboxModel"), ("kosmos-2", "Kosmos2Model"), ("layoutlm", "LayoutLMModel"), @@ -458,6 +459,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("jamba", "JambaForCausalLM"), + ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), ("mamba", "MambaForCausalLM"), ("marian", "MarianForCausalLM"), @@ -860,6 +862,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("gptj", "GPTJForSequenceClassification"), ("ibert", "IBertForSequenceClassification"), ("jamba", "JambaForSequenceClassification"), + ("jetmoe", "JetMoeForSequenceClassification"), ("layoutlm", "LayoutLMForSequenceClassification"), ("layoutlmv2", "LayoutLMv2ForSequenceClassification"), ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c883c2e836..8bb8757cb6 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -211,6 +211,13 @@ else: "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "jetmoe", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("jukebox", ("JukeboxTokenizer", None)), ( "kosmos-2", diff --git a/src/transformers/models/jetmoe/__init__.py b/src/transformers/models/jetmoe/__init__.py new file mode 100644 index 0000000000..48ac583a6a --- /dev/null +++ b/src/transformers/models/jetmoe/__init__.py @@ -0,0 +1,56 @@ +# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_jetmoe": ["JetMoeConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_jetmoe"] = [ + "JetMoeForCausalLM", + "JetMoeModel", + "JetMoePreTrainedModel", + "JetMoeForSequenceClassification", + ] + +if TYPE_CHECKING: + from .configuration_jetmoe import JetMoeConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_jetmoe import ( + JetMoeForCausalLM, + JetMoeForSequenceClassification, + JetMoeModel, + JetMoePreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/jetmoe/configuration_jetmoe.py b/src/transformers/models/jetmoe/configuration_jetmoe.py new file mode 100644 index 0000000000..c6913faee1 --- /dev/null +++ b/src/transformers/models/jetmoe/configuration_jetmoe.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JetMoe model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class JetMoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`JetMoeModel`]. It is used to instantiate a + JetMoe model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a configuration of the JetMoe-4B. + + [jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the JetMoe model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`JetMoeModel`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each key and value in the Transformer encoder. + kv_channels (`int`, *optional*, defaults to 128): + Defines the number of channels for the key and value tensors. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. 
JetMoe's attention allows sequences of
+            up to 4096 tokens.
+        activation_function (`string`, *optional*, defaults to `"silu"`):
+            Defines the activation function for MLP experts.
+        num_local_experts (`int`, *optional*, defaults to 8):
+            Defines the number of experts in the MoE and MoA.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token, for both the MoE and the MoA layers.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss.
+        aux_loss_coef (`float`, *optional*, defaults to 0.01):
+            The coefficient for the auxiliary loss.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        initializer_range (`float`, *optional*, defaults to 0.01):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> from transformers import JetMoeModel, JetMoeConfig
+
+    >>> # Initializing a JetMoe 4B style configuration
+    >>> configuration = JetMoeConfig()
+
+    >>> # Initializing a model from the JetMoe 4B style configuration
+    >>> model = JetMoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "jetmoe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=2048,
+        num_hidden_layers=12,
+        num_key_value_heads=16,
+        kv_channels=128,
+        intermediate_size=5632,
+        max_position_embeddings=4096,
+        activation_function="silu",
+        num_local_experts=8,
+        num_experts_per_tok=2,
+        output_router_logits=False,
+        aux_loss_coef=0.01,
+        use_cache=True,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=True,
+        rope_theta=10000.0,
+        rms_norm_eps=1e-6,
+        initializer_range=0.01,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        if num_experts_per_tok > num_local_experts:
+            raise ValueError("`num_experts_per_tok` must be less than or equal to `num_local_experts`")
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_key_value_heads * num_experts_per_tok
+        self.num_key_value_heads = num_key_value_heads
+        self.kv_channels = kv_channels
+        self.intermediate_size = intermediate_size
+        self.max_position_embeddings = max_position_embeddings
+        self.activation_function = activation_function
+        self.num_local_experts = num_local_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.output_router_logits = output_router_logits
+        self.aux_loss_coef = aux_loss_coef
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+
+        self.bos_token_id =
bos_token_id + self.eos_token_id = eos_token_id + + self.rope_theta = rope_theta + self.rms_norm_eps = rms_norm_eps + + super().__init__( + bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs + ) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py new file mode 100644 index 0000000000..15571ce9a3 --- /dev/null +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -0,0 +1,1603 @@ +# coding=utf-8 +# Copyright 2024 JetMoe AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch JetMoe model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_jetmoe import JetMoeConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "jetmoe" +_CONFIG_FOR_DOC = "JetMoeConfig" + + +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class JetMoeParallelExperts(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Initialize the JetMoeParallelExperts module. + The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's comptible with + many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and + [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the + [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py) + used in vllm. + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size)) + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def forward(self, inputs, expert_size): + """ + Forward pass of the JetMoeParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. 
+ """ + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = torch.cat(output_list, dim=0) + return results + + +class JetMoeTopKGating(nn.Module): + def __init__(self, input_size: int, num_experts: int, top_k: int): + """ + Initialize the top-k gating mechanism. + + Args: + input_size (`int`): + Size of the input. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top experts to select. + """ + super().__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.top_k = top_k + + self.layer = nn.Linear(input_size, num_experts, bias=False) + + def forward(self, hidden_states): + # compute the top_k routing decision + logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] + top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] + top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] + + # compute number of input given to each expert + zeros = torch.zeros( + [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device + ) # [num_tokens, num_experts] + gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] + expert_size = gates.long().sum(0) # [num_experts,] + expert_size = expert_size.tolist() + + # sort and group input tokens according to expert assignment + top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] + _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] + batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] + + # gather the gate values for grouped input tokens + top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] + batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] + + return index_sorted_experts, batch_index, batch_gates, expert_size, logits + + +class JetMoeMoE(nn.Module): + """ + A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: JetMoeConfig): + super(JetMoeMoE, self).__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.intermediate_size + self.activation = ACT2FN[config.activation_function] + self.bias = torch.nn.Parameter(torch.empty(self.input_size)) + self.input_linear = JetMoeParallelExperts(config.num_local_experts, self.input_size, self.hidden_size * 2) + self.output_linear = JetMoeParallelExperts(config.num_local_experts, self.hidden_size, self.input_size) + + self.router = JetMoeTopKGating( + input_size=self.input_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + ) + + def forward(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. 
+ """ + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) + _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + + expert_inputs = layer_input[batch_index] + hidden_states = self.input_linear(expert_inputs, expert_size) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_size) + + expert_outputs = expert_outputs * batch_gates[:, None] + + zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + layer_output = layer_output + self.bias + return layer_output, router_logits + + +class JetMoeMoA(nn.Module): + """ + A Sparsely gated mixture of attention layer with pairs of query- and output-projections as experts. + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: JetMoeConfig): + super(JetMoeMoA, self).__init__() + + self.num_experts = config.num_local_experts + self.input_size = config.hidden_size + self.hidden_size = config.kv_channels * config.num_key_value_heads + self.top_k = config.num_experts_per_tok + self.bias = torch.nn.Parameter(torch.empty(self.input_size)) + + self.input_linear = JetMoeParallelExperts(self.num_experts, self.input_size, self.hidden_size) + self.output_linear = JetMoeParallelExperts(self.num_experts, self.hidden_size, self.input_size) + + self.router = JetMoeTopKGating( + input_size=self.input_size, + num_experts=self.num_experts, + top_k=self.top_k, + ) + + def map(self, layer_input): + """ + Map inputs to attention experts according to routing decision and compute query projection inside each experts. + """ + + # Compute gating topology + bsz, length, emb_size = layer_input.size() + layer_input = layer_input.reshape(-1, emb_size) # [bsz * length, emb_size] + index_sorted_experts, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + topo_info = (index_sorted_experts, batch_index, batch_gates, expert_size) + + # Group inputs according to topology and compute query projection + expert_inputs = layer_input[batch_index] # [bsz * length * top_k, emb_size] + expert_outputs = self.input_linear(expert_inputs, expert_size) # [bsz * length * top_k, hidden_size] + + # Ungroup queries back to original order + zeros = torch.zeros( + (bsz * length * self.top_k, self.hidden_size), dtype=expert_outputs.dtype, device=expert_outputs.device + ) + layer_output = zeros.index_add(0, index_sorted_experts, expert_outputs) + layer_output = layer_output.view(bsz, length, self.top_k, -1) # [bsz, length, top_k, hidden_size] + return layer_output, router_logits, topo_info + + def reduce(self, layer_input, topo_info): + """ + Compute output projection inside each attention experts and merge the outputs of different experts. 
+ """ + bsz, length, k, hidden_size = layer_input.size() + layer_input = layer_input.reshape(-1, hidden_size) # [bsz * length * k, hidden_size] + index_sorted_experts, batch_index, batch_gates, expert_size = topo_info + + # Group inputs according to topology and compute output projection + expert_inputs = layer_input[index_sorted_experts] # [bsz * length * top_k, hidden_size] + expert_outputs = self.output_linear(expert_inputs, expert_size) # [bsz * length * top_k, emb_size] + + # Apply gates to attention expert outputs + expert_outputs = expert_outputs * batch_gates[:, None] + + # Ungroup and merge outputs to original order + zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + layer_output = layer_output + self.bias + return layer_output + + def forward(self, layer_input): + raise NotImplementedError("This module doesn't support call and forward.") + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->JetMoe +class JetMoeRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + JetMoeRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe +class JetMoeRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from 
transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class JetMoeAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + """ + + def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): + """ + Initialize the JetMoeAttention module. + + Args: + config: + Configuration object with model hyperparameters. + layer_idx: + Index of the layer in the model. + """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.is_causal = True + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.top_k = config.num_experts_per_tok + self.attention_dropout = config.attention_dropout + self.kv_projection_size = config.kv_channels * config.num_key_value_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_heads = config.num_attention_heads + self.head_dim = config.kv_channels + + self.experts = JetMoeMoA(config) + + self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) + + self.rotary_emb = JetMoeRotaryEmbedding( + config.kv_channels, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states, router_logits, topo_info = self.experts.map(hidden_states) + key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads for top-k attention experts + key_states = key_states.repeat(1, self.top_k, 1, 1) + value_states = value_states.repeat(1, self.top_k, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) + + attn_output = self.experts.reduce(attn_output, topo_info) + attn_output = attn_output.view(bsz, q_len, -1) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, router_logits + + +class JetMoeSdpaAttention(JetMoeAttention): + """ + JetMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `JetMoeAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from JetMoeAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]], Optional[torch.Tensor]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "JetMoeModel is using JetMoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states, router_logits, topo_info = self.experts.map(hidden_states) + key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads for top-k attention experts + key_states = key_states.repeat(1, self.top_k, 1, 1) + value_states = value_states.repeat(1, self.top_k, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather + # relying on the `is_causal` argument. 
+ attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) + + attn_output = self.experts.reduce(attn_output, topo_info) + attn_output = attn_output.view(bsz, q_len, -1) + + return attn_output, None, past_key_value, router_logits + + +class JetMoeFlashAttention2(JetMoeAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + """ + Forward pass of the JetMoeAttention module. + + Args: + hidden_states (Optional[torch.FloatTensor]): Input hidden states. + attention_mask (Optional[torch.FloatTensor]): Attention mask. + layer_past (Optional[Tuple[torch.Tensor]]): Past layer state. + use_cache (Optional[bool]): Whether to use cached states. + output_attentions (Optional[bool]): Whether to output attention weights. + cache_position (Optional[torch.LongTensor]): Position of the cache. + + Returns: + Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs. 
+ """ + output_attentions = False + bsz, q_len, hidden_size = hidden_states.size() + + # calculate query, key, values + query_states, router_logits, topo_info = self.experts.map(hidden_states) + key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads for top-k attention experts + key_states = key_states.repeat(1, self.top_k, 1, 1) + value_states = value_states.repeat(1, self.top_k, 1, 1) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.kv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ).to(input_dtype) + + # output projection + attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) + attn_output = self.experts.reduce(attn_output, topo_info) + attn_output = attn_output.view(bsz, q_len, hidden_size) # re-assemble all head outputs side by side + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, router_logits + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +JETMOE_ATTENTION_CLASSES = { + "eager": JetMoeAttention, + "flash_attention_2": JetMoeFlashAttention2, + "sdpa": JetMoeSdpaAttention, +} + + +class JetMoeBlock(nn.Module): + def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): + """ + Initialize the JetMoeBlock module. + + Args: + config: + Configuration object with model hyperparameters. 
+ """ + super().__init__() + self.input_layernorm = JetMoeRMSNorm(config.hidden_size) + self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.post_attention_layernorm = JetMoeRMSNorm(config.hidden_size) + + self.mlp = JetMoeMoE(config) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + # Self Attention + attn_output, self_attn_weights, present_key_value, attn_router_logits = self.self_attention( + hidden_states=self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states + attn_output + x_mlp, mlp_router_logits = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = hidden_states + x_mlp + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += attn_router_logits, mlp_router_logits + + return outputs + + +class JetMoePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = JetMoeConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = False + _no_split_modules = ["JetMoeBlock"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, JetMoeParallelExperts): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, JetMoeMoA): + module.bias.data.zero_() + elif isinstance(module, JetMoeMoE): + module.bias.data.zero_() + + +JETMOE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`JetMoeConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+JETMOE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
+            should not be returned during inference.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    "The bare JetMoe Model outputting raw hidden-states without any specific head on top.",
+    JETMOE_START_DOCSTRING,
+)
+class JetMoeModel(JetMoePreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`JetMoeBlock`] + + Args: + config: + JetMoeConfig + """ + + def __init__(self, config: JetMoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([JetMoeBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self._attn_implementation = config._attn_implementation + self.norm = JetMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + batch_size = inputs_embeds.shape[0] + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_ids, + past_key_values, + causal_mask, + output_attentions, + output_router_logits, + use_cache, + use_reentrant=False, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-2], layer_outputs[-1]) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class JetMoeForCausalLM(JetMoePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = JetMoeModel(config) + self.vocab_size = config.vocab_size + self.aux_loss_coef = config.aux_loss_coef + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.tie_word_embeddings = config.tie_word_embeddings + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Ensure tensors are on the same device + shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + output_router_logits=False, + **kwargs, + ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) + has_static_cache = past_key_values is not None + + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. 
when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + else: + cache_position = cache_position[-input_length:] + + if has_static_cache: + past_key_values = None + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The JetMoe Model transformer with a sequence classification head on top (linear layer). + + [`JetMoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + JETMOE_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->JetMoe, LLAMA->JETMOE +class JetMoeForSequenceClassification(JetMoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = JetMoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(JETMOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 
"single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 7d9813d150..9570e044d3 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4334,6 +4334,34 @@ class JambaPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class JetMoeForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JetMoeForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JetMoeModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class JetMoePreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class JukeboxModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/jetmoe/__init__.py b/tests/models/jetmoe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py new file mode 100644 index 0000000000..12e5dd682c --- /dev/null +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -0,0 +1,536 @@ +# coding=utf-8 +# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch JetMoe model.""" + +import gc +import tempfile +import unittest + +import pytest +from parameterized import parameterized + +from transformers import AutoTokenizer, JetMoeConfig, is_torch_available +from transformers.testing_utils import ( + backend_empty_cache, + is_flaky, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + JetMoeForCausalLM, + JetMoeForSequenceClassification, + JetMoeModel, + ) + + +class JetMoeModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_key_value_heads=2, + kv_channels=8, + intermediate_size=37, + hidden_act="silu", + num_local_experts=4, + num_experts_per_tok=2, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.kv_channels = kv_channels + self.num_attention_heads = num_key_value_heads * num_experts_per_tok + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.ones(self.batch_size, self.seq_length).to(torch_device) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return JetMoeConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_key_value_heads=self.num_key_value_heads, + kv_channels=self.kv_channels, + intermediate_size=self.intermediate_size, + activation_function=self.hidden_act, + 
num_local_experts=self.num_local_experts, + num_experts_per_tok=self.num_experts_per_tok, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = JetMoeModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = JetMoeModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = JetMoeForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = JetMoeForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = 
output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + (JetMoeModel, JetMoeForCausalLM, JetMoeForSequenceClassification) if is_torch_available() else () + ) + all_generative_model_classes = (JetMoeForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": JetMoeModel, + "text-classification": JetMoeForSequenceClassification, + "text-generation": JetMoeForCausalLM, + "zero-shot": JetMoeForSequenceClassification, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + test_mismatched_shapes = False + test_cpu_offload = False + test_disk_offload_bin = False + test_disk_offload_safetensors = False + + # TODO: @Fxmarty + @is_flaky(max_attempts=3, description="flaky on some models.") + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + super().test_eager_matches_sdpa_generate() + + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + def setUp(self): + self.model_tester = JetMoeModelTester(self) + self.config_tester = ConfigTester( + self, config_class=JetMoeConfig, common_properties=["hidden_size", "num_hidden_layers"] + ) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config + def test_config(self): + self.config_tester.run_common_tests() + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with llama->jetmoe, Llama->JetMoe + def test_jetmoe_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = JetMoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, 
self.model_tester.num_labels)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with llama->jetmoe, Llama->JetMoe + def test_jetmoe_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = JetMoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with llama->jetmoe, Llama->JetMoe + def test_jetmoe_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = JetMoeForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + @unittest.skip("JetMoe buffers include complex numbers, which breaks this test") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip("JetMoe uses MoA on all models so the KV cache is a non standard format") + def test_past_key_values_format(self): + pass + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = 
inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: JetMoe apparently does not support right padding + use_cache with FA2. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest("JetMoe flash attention does not support right padding") + + +@require_torch +class JetMoeIntegrationTest(unittest.TestCase): + @slow + def test_model_8b_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto") + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + with torch.no_grad(): + out = model(input_ids).logits.cpu() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([-3.3689, 5.9006, 5.7450, -1.7012, -4.7072, -4.7071, -4.7071, -4.7071, -4.7072, -4.7072, -4.7072, -4.7071, 3.8321, 9.1746, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_model_8b_generation(self): + EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love""" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False) + model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=10, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + def test_model_8b_batched_generation(self): + EXPECTED_TEXT_COMPLETION = [ + """My favourite condiment is ....\nI love ketchup. 
I love""",
+            """My favourite 2018 Christmas present was a new pair""",
+        ]
+        prompt = [
+            "My favourite condiment is ",
+            "My favourite ",
+        ]
+        tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
+        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
+        input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        generated_ids = model.generate(**input_ids, max_new_tokens=10, temperature=0)
+        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
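
As a side note on the masking logic in the patch: `JetMoeModel._update_causal_mask` (copied from Llama) derives the additive causal mask from `cache_position` rather than from a growing 2D mask. The snippet below is a minimal, self-contained sketch of that arithmetic on toy sizes; the sequence length, cache offset, and dtype are chosen purely for illustration and are not part of the patch.

```python
import torch

# Toy setup: 2 tokens are already in the KV cache, 3 new tokens arrive.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length = 3                                    # new query tokens
target_length = 5                                      # cached + new key/value tokens
cache_position = torch.arange(2, 2 + sequence_length)  # absolute positions 2, 3, 4

# Same construction as in _update_causal_mask: begin with a fully masked block,
# keep only the strictly upper triangle masked, and let the comparison against
# cache_position unmask every key at or before each query's absolute position.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

# 1 = position may be attended to, 0 = masked out.
print((causal_mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```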
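
Likewise, `JetMoeForSequenceClassification` (copied from the Llama implementation) classifies from the hidden state of the last non-padding token, using a modulo trick so that rows without padding fall back to the final position. Here is a standalone sketch of that index computation with a toy batch; the `pad_token_id` of 0 and the random logits are assumptions made for the example only.

```python
import torch

pad_token_id = 0  # assumed padding id for this toy example
num_labels = 3

input_ids = torch.tensor(
    [
        [5, 7, 9, 0, 0],  # 3 real tokens followed by 2 pads
        [4, 4, 4, 4, 2],  # no padding at all
    ]
)
batch_size, seq_len = input_ids.shape

# Same index arithmetic as in JetMoeForSequenceClassification.forward:
# position of the first pad token minus one; the modulo wraps -1 around to the
# last position when a row contains no padding (and keeps the op ONNX-friendly).
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])

# Pool the per-token scores down to one row per sequence.
logits = torch.randn(batch_size, seq_len, num_labels)
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```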