diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
index b16a358f20..942c0712be 100644
--- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
@@ -91,6 +91,10 @@ class Qwen2MoeConfig(PretrainedConfig):
             allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
             The aux loss factor for the total loss.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
 
     ```python
     >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
@@ -135,6 +139,7 @@ class Qwen2MoeConfig(PretrainedConfig):
         norm_topk_prob=False,
         output_router_logits=False,
         router_aux_loss_coef=0.001,
+        mlp_only_layers=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -164,6 +169,7 @@ class Qwen2MoeConfig(PretrainedConfig):
         self.norm_topk_prob = norm_topk_prob
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
 
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index ef1dd23cde..838425505b 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -17,7 +17,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch Qwen2MoE model."""
+"""PyTorch Qwen2MoE model."""
+
 import inspect
 import math
 import warnings
@@ -861,7 +862,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
 
         self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
-        if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0:
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
             self.mlp = Qwen2MoeSparseMoeBlock(config)
         else:
             self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
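
The change wires `mlp_only_layers` through the config and into `Qwen2MoeDecoderLayer.__init__`, so any listed layer index falls back to the dense `Qwen2MoeMLP` even when `decoder_sparse_step` would otherwise make it sparse. Below is a minimal usage sketch (not part of the patch) that exercises the new option on a deliberately tiny, made-up config; the sizes are illustrative assumptions, not the released Qwen1.5-MoE defaults.

```python
# Sketch: build a tiny Qwen2MoE model and check which layers are dense vs. sparse.
# All dimensions here are small placeholders chosen only to keep the example cheap.
from transformers import Qwen2MoeConfig, Qwen2MoeModel
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

config = Qwen2MoeConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    moe_intermediate_size=64,
    shared_expert_intermediate_size=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_experts=8,
    num_experts_per_tok=2,
    decoder_sparse_step=1,   # every layer would normally get a sparse MoE block
    mlp_only_layers=[0, 2],  # ...except these, which should use the dense Qwen2MoeMLP
)
model = Qwen2MoeModel(config)

for idx, layer in enumerate(model.layers):
    kind = "sparse MoE" if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock) else "dense MLP"
    print(f"layer {idx}: {kind}")
# Expected with this patch: layers 0 and 2 report "dense MLP", layers 1 and 3 "sparse MoE".
```

Since `mlp_only_layers` defaults to `None` (stored as `[]`), existing configs and checkpoints keep their current behavior: layer selection is still driven solely by `num_experts` and `decoder_sparse_step`.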