mlp_only_layers is more flexible than decoder_sparse_step (#30552)
* force back to commit ba40a21 and fix workflow errors * match the review suggestions * fix ci errors * fix CI * fix ci, format code * fix ci, ruff format * fix ci, ruff format again * Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * solve this warning: Default Argument Value is mutable --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
73fcfb2861
commit
1c52cb7b3b
|
@ -91,6 +91,10 @@ class Qwen2MoeConfig(PretrainedConfig):
|
|||
allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
|
||||
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
|
||||
The aux loss factor for the total loss.
|
||||
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
|
||||
Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
|
||||
The list contains layer indices, ranging from 0 to num_layers - 1 for a model with num_layers layers.
|
||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
|
||||
|
@ -135,6 +139,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
|||
norm_topk_prob=False,
|
||||
output_router_logits=False,
|
||||
router_aux_loss_coef=0.001,
|
||||
mlp_only_layers=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
@ -164,6 +169,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
|||
self.norm_topk_prob = norm_topk_prob
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Qwen2MoE model."""
|
||||
"""PyTorch Qwen2MoE model."""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import warnings
|
||||
|
@ -861,7 +862,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||
|
||||
self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
||||
|
||||
if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0:
|
||||
if (layer_idx not in config.mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
|
||||
):
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(config)
|
||||
else:
|
||||
self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
|
||||
|
|
Loading…
Reference in New Issue