1- """Tests for SM_100 (B200) NVFP4 MoE 6-kernel pipeline.
1+ """Tests for SM_100 (B200) NVFP4 MoE pipeline with init/run split.
22
33Requires a B200 GPU (compute capability 10.0).
44Tests:
5- 1. Build verification (CUTLASS kernels compile and load)
5+ 1. Build verification (CUTLASS kernels compile and load, including weighted gather)
662. Individual kernel correctness (scatter, gather, quantize_raw, scale_to_blocked_batched)
773. Full MoE pipeline correctness (compare against reference implementation)
8- 4. Kernel launch count verification
8+ 4. Tile size selection (small M vs large M)
9+ 5. Init/run caching behavior
910"""
1011
1112import pytest
@@ -92,6 +93,11 @@ def test_scale_to_blocked_batched_exists(self):
9293 assert hasattr (lib , "cscale_to_blocked_batched" ), \
9394 "Batched scale swizzle kernel not found"
9495
def test_weighted_gather_exists(self):
    """Check that the loaded extension exposes ``cmoe_weighted_gather_bf16``.

    The library handle is imported inside the test so that an import
    problem is reported against this test rather than at collection time.
    """
    from bitsandbytes.cextension import lib

    kernel_present = hasattr(lib, "cmoe_weighted_gather_bf16")
    assert kernel_present, "Weighted gather kernel not found — moe_scatter_gather.cu not updated"
100+
95101 def test_fused_quantize_exists (self ):
96102 from bitsandbytes .cextension import lib
97103 assert hasattr (lib , "cfused_quantize_nvfp4_quest" ), \
@@ -331,40 +337,28 @@ def test_pipeline_nan_diagnosis(self, small_moe_config):
331337
332338 # Step 2: quantize
333339 global_scale = (1.0 / act_scale ).to (torch .float32 )
334- print (f" Step 2 (global_scale = 1/abs_max): { global_scale .item ():.8f} " )
335340 packed_all , scales_all = quantize_nvfp4_raw (x_2d , global_scale )
336- print (f" Step 2 (quantize): packed={ packed_all .shape } , scales={ scales_all .shape } , "
337- f"packed_nonzero={ packed_all .count_nonzero ().item ()} /{ packed_all .numel ()} " )
341+ print (f" Step 2 (quantize): packed={ packed_all .shape } , scales={ scales_all .shape } " )
338342
339343 # Step 3: scatter
340344 packed_batched = moe_scatter_nvfp4 (packed_all , expert_offsets_i32 , max_M , K , num_experts )
341- print (f" Step 3 (scatter): shape={ packed_batched .shape } , "
342- f"nonzero={ packed_batched .count_nonzero ().item ()} /{ packed_batched .numel ()} " )
345+ print (f" Step 3 (scatter): shape={ packed_batched .shape } " )
343346
344347 # Step 4: swizzle scales
345348 sfa_batched = scale_to_blocked_batched (scales_all , expert_offsets_i32 , max_M , K , num_experts )
346- print (f" Step 4 (swizzle): shape={ sfa_batched .shape } , "
347- f"nonzero={ sfa_batched .count_nonzero ().item ()} /{ sfa_batched .numel ()} " )
349+ print (f" Step 4 (swizzle): shape={ sfa_batched .shape } " )
348350
349- # Step 5: GEMM
351+ # Step 5: GEMM (uses init/run split internally)
350352 alpha_dev = (act_scale * layer .weight_tensor_scale ).to (torch .float32 )
351- print (f" Step 5 (alpha): { alpha_dev .item ():.6f} , "
352- f"weight_tensor_scale={ layer .weight_tensor_scale :.6f} " )
353- print (f" Step 5 (weight_packed): shape={ layer .weight_packed .shape } , "
354- f"nonzero={ layer .weight_packed .count_nonzero ().item ()} /{ layer .weight_packed .numel ()} " )
355- print (f" Step 5 (weight_scales_batched): shape={ layer .weight_scales_batched .shape } , "
356- f"nonzero={ layer .weight_scales_batched .count_nonzero ().item ()} /{ layer .weight_scales_batched .numel ()} " )
357-
358353 D = gemm_nvfp4_moe (
359354 packed_batched , sfa_batched , alpha_dev ,
360355 layer .weight_packed , layer .weight_scales_batched ,
361356 max_M , N , K , num_experts ,
362357 )
363358 torch .cuda .synchronize ()
364359 nan_D = torch .isnan (D ).sum ().item ()
365- inf_D = torch .isinf (D ).sum ().item ()
366360 print (f" Step 5 (GEMM out): shape={ D .shape } , nan={ nan_D } /{ D .numel ()} , "
367- f"inf= { inf_D } , abs_max={ D [~ torch .isnan (D )].abs ().max ().item () if nan_D < D .numel () else 'all_nan' } " )
361+ f"abs_max={ D [~ torch .isnan (D )].abs ().max ().item () if nan_D < D .numel () else 'all_nan' } " )
368362
369363 # Step 6: gather
370364 D_flat = D .view (- 1 ).contiguous ()
@@ -373,22 +367,9 @@ def test_pipeline_nan_diagnosis(self, small_moe_config):
373367 nan_out = torch .isnan (out ).sum ().item ()
374368 print (f" Step 6 (gather): shape={ out .shape } , nan={ nan_out } /{ out .numel ()} " )
375369
376- # Try with scalar alpha=1.0 to isolate alpha_ptr issue
377- alpha_one = torch .tensor ([1.0 ], dtype = torch .float32 , device = "cuda" )
378- D2 = gemm_nvfp4_moe (
379- packed_batched , sfa_batched , alpha_one ,
380- layer .weight_packed , layer .weight_scales_batched ,
381- max_M , N , K , num_experts ,
382- )
383- torch .cuda .synchronize ()
384- nan_D2 = torch .isnan (D2 ).sum ().item ()
385- print (f" GEMM with alpha=1.0: nan={ nan_D2 } /{ D2 .numel ()} , "
386- f"abs_max={ D2 [~ torch .isnan (D2 )].abs ().max ().item () if nan_D2 < D2 .numel () else 'all_nan' } " )
387-
388370 assert nan_D == 0 , \
389371 f"GEMM output has { nan_D } /{ D .numel ()} NaN elements"
390372
391- # Verify output is non-zero (weights should be non-zero after init)
392373 assert D .abs ().max ().item () > 0 , \
393374 f"GEMM output is all zeros despite non-zero weights"
394375
@@ -509,17 +490,75 @@ def test_device_alpha_produces_output(self, small_moe_config):
509490 assert alpha_dev .dtype == torch .float32
510491
511492
512- class TestKernelLaunchCount :
513- """Verify the pipeline uses the expected number of kernel launches ."""
class TestTileSelection:
    """Run the MoE layer with per-expert token counts on both sides of M = 512.

    NOTE(review): which tile the backend actually picks is not observable
    from these tests; they only check that the forward pass produces a
    well-formed output for each token distribution.
    """

    # Feature dimensions shared by both scenarios.
    _K = 256
    _N = 512

    def _forward_and_check(self, tokens_per_expert, label):
        """Build a layer, run one forward pass and validate the result.

        The leading underscore keeps pytest from collecting this helper
        as a test of its own.

        Args:
            tokens_per_expert: Number of tokens routed to each expert; the
                length of the list is used as the number of experts.
            label: Prefix for the assertion messages (e.g. "Small tile").
        """
        total_tokens = sum(tokens_per_expert)

        layer = _make_moe_layer(len(tokens_per_expert), self._K, self._N, bias=False)
        x = torch.randn(total_tokens, self._K, dtype=torch.bfloat16, device="cuda")
        expert_offsets = _make_expert_offsets(tokens_per_expert)

        out = layer(x, expert_offsets)
        assert out.shape == (total_tokens, self._N)
        assert not torch.isnan(out).any(), f"{label} output has NaN"
        # An all-zero result would slip past the NaN check, so reject it
        # explicitly, as the other pipeline tests in this file do.
        assert out.abs().max().item() > 0, f"{label} output is all zeros"

    def test_small_m_uses_small_tile(self):
        """M < 512 should trigger the small tile (128x128x256)."""
        # The busiest expert receives 8 tokens, far below the 512 threshold.
        # (The original note gave max_M = 128, presumably after padding —
        # TODO confirm against the layer implementation.)
        self._forward_and_check([4, 8, 2, 6], "Small tile")

    def test_large_m_uses_large_tile(self):
        """M >= 512 should trigger the large tile (128x256x256)."""
        # The busiest expert receives 512 tokens, i.e. exactly the threshold.
        self._forward_and_check([512, 256], "Large tile")
529+
530+
class TestInitRunCaching:
    """Back-to-back forward passes through one layer must all give usable output."""

    def test_repeated_calls_same_dims(self, small_moe_config):
        """Call the same layer three times with identical dimensions.

        Each call gets a fresh random input. The outputs are inspected only
        after all three calls have been issued. Per the test's intent the
        later calls should be served by the cached init; the cache itself is
        not observed here, only the validity of the outputs.
        """
        in_features = small_moe_config["input_features"]
        out_features = small_moe_config["output_features"]
        n_experts = small_moe_config["num_experts"]
        per_expert = small_moe_config["tokens_per_expert"]
        n_tokens = sum(per_expert)

        moe = _make_moe_layer(n_experts, in_features, out_features, bias=False)
        offsets = _make_expert_offsets(per_expert)

        def _one_call():
            # Copy the output so each snapshot has its own storage
            # (presumably guarding against buffer reuse between calls).
            inp = torch.randn(n_tokens, in_features, dtype=torch.bfloat16, device="cuda")
            return moe(inp, offsets).clone()

        snapshots = [_one_call() for _ in range(3)]

        # Every snapshot must be free of NaN (e.g. from stale pointers) and non-zero.
        for call_idx, snap in enumerate(snapshots):
            assert not torch.isnan(snap).any(), f"Call {call_idx} produced NaN"
            assert snap.abs().max().item() > 0, f"Call {call_idx} produced all zeros"
555+
556+
557+ class TestStreamCompatibility :
558+ """Verify the pipeline works on non-default streams."""
559+
560+ def test_no_item_in_compute_path (self , small_moe_config ):
561+ """Verify the pipeline can run entirely within a CUDA stream."""
523562 K = small_moe_config ["input_features" ]
524563 N = small_moe_config ["output_features" ]
525564 num_experts = small_moe_config ["num_experts" ]
0 commit comments