@@ -189,98 +189,48 @@ def _compare_buckets(ref_buckets, test_buckets, atol=1e-6):
189189# tests
190190# ---------------------------------------------------------------------------
191191@pytest .mark .parametrize (
192- "n_tokens, n_expts_act, n_expts_tot" ,
192+ "n_tokens, n_expts_act, n_expts_tot, n_expts_global " ,
193193 [
194- # V4-Flash decode shapes (E=256, K=6).
195- (1 , 6 , 256 ),
196- (16 , 6 , 256 ),
197- (64 , 6 , 256 ),
198- (256 , 6 , 256 ),
194+ # V4-Flash decode shapes (E=256, K=6). n_expts_global ignored when
195+ # has_expert_map=False.
196+ (1 , 6 , 256 , 256 ),
197+ (16 , 6 , 256 , 256 ),
198+ (64 , 6 , 256 , 256 ),
199+ (256 , 6 , 256 , 256 ),
199200 # Generic decode shapes used by other MoE configs.
200- (1 , 8 , 384 ),
201- (4 , 8 , 384 ),
202- (64 , 8 , 384 ),
203- (256 , 8 , 384 ),
201+ (1 , 8 , 384 , 384 ),
202+ (4 , 8 , 384 , 384 ),
203+ (64 , 8 , 384 , 384 ),
204+ (256 , 8 , 384 , 384 ),
204205 # Edge: small E.
205- (32 , 4 , 16 ),
206+ (32 , 4 , 16 , 16 ),
206207 # Boundary: NK at the kernel's MAX_NK = 4096.
207- (512 , 8 , 384 ),
208- ],
209- )
210- @pytest .mark .parametrize ("dtype" , [torch .float32 ])
211- def test_fused_routing_from_topk (n_tokens , n_expts_act , n_expts_tot , dtype ):
212- if not torch .cuda .is_available ():
213- pytest .skip ("CUDA not available" )
214- torch .manual_seed (0 )
215- topk_ids , topk_weights = _make_inputs (
216- n_tokens , n_expts_act , n_expts_tot , dtype , DEVICE , seed = 0
217- )
218-
219- ref_hist , ref_topk_indx , ref_gate_indx , ref_gate_scal = routing_from_topk_reference (
220- topk_weights , topk_ids , n_expts_tot
221- )
222- _check_routing_invariants (
223- ref_hist ,
224- ref_topk_indx ,
225- ref_gate_indx ,
226- ref_gate_scal ,
227- topk_ids ,
228- n_expts_tot ,
229- bucket_unsorted_layout = False , # ref uses per-row-sorted layout
230- )
231- ground_buckets = _ground_truth_buckets (topk_ids , topk_weights )
232- ref_buckets = _per_expert_triples (
233- ref_hist , ref_topk_indx , ref_gate_scal , n_expts_act
234- )
235- _compare_buckets (ground_buckets , ref_buckets )
236-
237- test_hist , test_topk_indx , test_gate_indx , test_gate_scal = fused_routing_from_topk (
238- topk_weights , topk_ids , n_expts_tot
239- )
240- _check_routing_invariants (
241- test_hist ,
242- test_topk_indx ,
243- test_gate_indx ,
244- test_gate_scal ,
245- topk_ids ,
246- n_expts_tot ,
247- bucket_unsorted_layout = True , # fused uses unsorted topk_ids layout
248- )
249-
250- # hist must match the reference exactly.
251- assert torch .equal (
252- ref_hist , test_hist
253- ), f"hist mismatch:\n ref={ ref_hist } \n fused={ test_hist } "
254-
255- # Per-expert (token, weight) multisets match the reference.
256- test_buckets = _per_expert_triples (
257- test_hist , test_topk_indx , test_gate_scal , n_expts_act
258- )
259- _compare_buckets (ref_buckets , test_buckets )
260-
261-
262- @pytest .mark .parametrize (
263- "n_tokens, n_expts_act, n_expts_tot,n_expts_global" ,
264- [
208+ (512 , 8 , 384 , 384 ),
209+ # Expert-parallel shapes: n_expts_global > n_expts_tot, requires map.
265210 (16 , 6 , 64 , 256 ),
266211 (64 , 6 , 128 , 256 ),
267212 ],
268213)
214+ @pytest .mark .parametrize ("has_expert_map" , [False , True ])
269215@pytest .mark .parametrize ("dtype" , [torch .float32 ])
270- def test_fused_routing_from_topk_with_expert_map (
271- n_tokens , n_expts_act , n_expts_tot , n_expts_global , dtype
216+ def test_fused_routing_from_topk (
217+ n_tokens , n_expts_act , n_expts_tot , n_expts_global , has_expert_map , dtype
272218):
273219 if not torch .cuda .is_available ():
274220 pytest .skip ("CUDA not available" )
275221 torch .manual_seed (0 )
222+
223+ id_range = n_expts_global if has_expert_map else n_expts_tot
276224 topk_ids , topk_weights = _make_inputs (
277- n_tokens , n_expts_act , n_expts_global , dtype , DEVICE , seed = 0
225+ n_tokens , n_expts_act , id_range , dtype , DEVICE , seed = 0
278226 )
279227
280- expert_map = torch .full ((n_expts_global ,), - 1 , dtype = torch .int32 , device = DEVICE )
281- expert_map [: n_expts_tot // 2 ] = torch .arange (
282- n_expts_tot // 2 , dtype = torch .int32 , device = DEVICE
283- )
228+ expert_map = None
229+ if has_expert_map :
230+ expert_map = torch .full ((n_expts_global ,), - 1 , dtype = torch .int32 , device = DEVICE )
231+ expert_map [: n_expts_tot // 2 ] = torch .arange (
232+ n_expts_tot // 2 , dtype = torch .int32 , device = DEVICE
233+ )
284234
285235 ref_hist , ref_topk_indx , ref_gate_indx , ref_gate_scal = routing_from_topk_reference (
286236 topk_weights , topk_ids , n_expts_tot , expert_map = expert_map
@@ -292,7 +242,7 @@ def test_fused_routing_from_topk_with_expert_map(
292242 ref_gate_scal ,
293243 topk_ids ,
294244 n_expts_tot ,
295- bucket_unsorted_layout = False ,
245+ bucket_unsorted_layout = False , # ref uses per-row-sorted layout
296246 )
297247
298248 test_hist , test_topk_indx , test_gate_indx , test_gate_scal = fused_routing_from_topk (
@@ -305,18 +255,33 @@ def test_fused_routing_from_topk_with_expert_map(
305255 test_gate_scal ,
306256 topk_ids ,
307257 n_expts_tot ,
308- bucket_unsorted_layout = False ,
258+ bucket_unsorted_layout = not has_expert_map ,
309259 )
310260
261+ # hist must match the reference exactly.
311262 assert torch .equal (
312263 ref_hist , test_hist
313264 ), f"hist mismatch:\n ref={ ref_hist } \n fused={ test_hist } "
314265
315- # Intra-expert ordering can differ between fused and reference,
316- # especially in expert-0 bucket where invalid experts are redirected.
317- # Compare zeroed-weight cardinality instead of elementwise positions.
318- ref_zero_count = int ((ref_gate_scal == 0 ).sum ().item ())
319- test_zero_count = int ((test_gate_scal == 0 ).sum ().item ())
320- assert (
321- ref_zero_count == test_zero_count
322- ), f"zero-masked count mismatch: ref={ ref_zero_count } , fused={ test_zero_count } "
266+ if has_expert_map :
267+ # Intra-expert ordering can differ between fused and reference,
268+ # especially in expert-0 bucket where invalid experts are redirected.
269+ # Compare zeroed-weight cardinality instead of elementwise positions.
270+ ref_zero_count = int ((ref_gate_scal == 0 ).sum ().item ())
271+ test_zero_count = int ((test_gate_scal == 0 ).sum ().item ())
272+ assert ref_zero_count == test_zero_count , (
273+ f"zero-masked count mismatch: "
274+ f"ref={ ref_zero_count } , fused={ test_zero_count } "
275+ )
276+ else :
277+ ground_buckets = _ground_truth_buckets (topk_ids , topk_weights )
278+ ref_buckets = _per_expert_triples (
279+ ref_hist , ref_topk_indx , ref_gate_scal , n_expts_act
280+ )
281+ _compare_buckets (ground_buckets , ref_buckets )
282+
283+ # Per-expert (token, weight) multisets match the reference.
284+ test_buckets = _per_expert_triples (
285+ test_hist , test_topk_indx , test_gate_scal , n_expts_act
286+ )
287+ _compare_buckets (ref_buckets , test_buckets )
0 commit comments