@@ -414,17 +414,20 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f
414414@pytest .mark .parametrize ("num_tokens" , [2048 , 7168 , 14234 ])
415415@pytest .mark .parametrize ("num_experts" , [1024 , 256 , 128 , 32 ])
416416@pytest .mark .parametrize ("topk" , [4 , 32 ])
417- def test_fused_moe_aux_loss (dtype , num_tokens , num_experts , topk ):
417+ @pytest .mark .parametrize ("expert_multiplier" , [1 , 2 ])
418+ def test_fused_moe_aux_loss (dtype , num_tokens , num_experts , topk , expert_multiplier ):
418419 if topk >= num_experts :
419420 pytest .skip (f"topk ({ topk } ) >= num_experts ({ num_experts } )" )
421+ # Sequence aux loss batches independent sequences along the expert dimension.
422+ num_cols = num_experts * expert_multiplier
420423 # Construct the special probs to avoid inf in the sigmoid function
421424 offset = torch .arange (- num_tokens // 2 , num_tokens // 2 , dtype = dtype , device = "cuda" ) * 1e-4
422- probs = torch .arange (- num_experts // 2 , num_experts // 2 , device = "cuda" , dtype = dtype ) * 1e-2
425+ probs = torch .arange (- num_cols // 2 , num_cols // 2 , device = "cuda" , dtype = dtype ) * 1e-2
423426 probs = probs .unsqueeze (0 ).repeat (num_tokens , 1 ) + offset .unsqueeze (1 )
424- probs = probs .view (num_tokens , num_experts )
427+ probs = probs .view (num_tokens , num_cols )
425428 probs .requires_grad = True
426429
427- tokens_per_expert = torch .randint (1 , 1000 , (num_experts ,), device = "cuda" , dtype = torch .int32 )
430+ tokens_per_expert = torch .randint (1 , 1000 , (num_cols ,), device = "cuda" , dtype = torch .int32 )
428431 coeff = 0.01
429432
430433 probs_clone = deepcopy (probs )
0 commit comments