Fix load balancing loss func for mixtral (#28256)

* Correct the implementation of auxiliary loss of mixtrtal

* correct the implementation of auxiliary loss of mixtrtal

* Implement a simpler calculation method

---------

Co-authored-by: zhangliangxu3 <zhangliangxu3@jd.com>
This commit is contained in:
liangxuZhang 2024-01-11 23:16:12 +08:00 committed by GitHub
parent 5d4d62d0a2
commit e768616afa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 6 deletions

View File

@ -103,11 +103,7 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
# treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`)
selected_experts = selected_experts.reshape(-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
expert_mask = torch.max(expert_mask, dim=-2).values
# Compute the percentage of tokens routed to each experts
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
@ -115,7 +111,7 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
# Compute the average probability of routing to these experts
router_prob_per_expert = torch.mean(routing_weights, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1))
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts

View File

@ -474,7 +474,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.eval()
result = model(input_ids, attention_mask=attention_mask)
self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(8, dtype=torch.float32))
torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2)
@require_torch