Skip to content

Commit 230efa9

Browse files
committed
[PyTorch] Cover expanded columns in fused MoE aux loss test
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
1 parent 20f1e30 commit 230efa9

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

tests/pytorch/test_fused_router.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -414,17 +414,20 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
414414
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
415415
@pytest.mark.parametrize("num_experts", [1024, 256, 128, 32])
416416
@pytest.mark.parametrize("topk", [4, 32])
417-
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
417+
@pytest.mark.parametrize("expert_multiplier", [1, 2])
418+
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk, expert_multiplier):
418419
if topk >= num_experts:
419420
pytest.skip(f"topk ({topk}) >= num_experts ({num_experts})")
421+
# Sequence aux loss batches independent sequences along the expert dimension.
422+
num_cols = num_experts * expert_multiplier
420423
# Construct the special probs to avoid inf in the sigmoid function
421424
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
422-
probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
425+
probs = torch.arange(-num_cols // 2, num_cols // 2, device="cuda", dtype=dtype) * 1e-2
423426
probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
424-
probs = probs.view(num_tokens, num_experts)
427+
probs = probs.view(num_tokens, num_cols)
425428
probs.requires_grad = True
426429

427-
tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
430+
tokens_per_expert = torch.randint(1, 1000, (num_cols,), device="cuda", dtype=torch.int32)
428431
coeff = 0.01
429432

430433
probs_clone = deepcopy(probs)

0 commit comments

Comments (0)