Skip to content

Commit 8bf8759

Browse files
TimDettmers and claude
committed
Update MoE SM_100 tests for init/run split pipeline
- Add weighted gather kernel existence test
- Add tile size selection tests (small M < 512 vs large M >= 512)
- Add init/run caching test (repeated calls with same dimensions)
- Rename kernel launch count test to stream compatibility test
- Simplify NaN diagnosis test (remove alpha=1.0 fallback check)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 90aecac commit 8bf8759

File tree

1 file changed

+80
-41
lines changed

1 file changed

+80
-41
lines changed

tests/test_moe_sm100_pipeline.py

Lines changed: 80 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
1-
"""Tests for SM_100 (B200) NVFP4 MoE 6-kernel pipeline.
1+
"""Tests for SM_100 (B200) NVFP4 MoE pipeline with init/run split.
22
33
Requires a B200 GPU (compute capability 10.0).
44
Tests:
5-
1. Build verification (CUTLASS kernels compile and load)
5+
1. Build verification (CUTLASS kernels compile and load, including weighted gather)
66
2. Individual kernel correctness (scatter, gather, quantize_raw, scale_to_blocked_batched)
77
3. Full MoE pipeline correctness (compare against reference implementation)
8-
4. Kernel launch count verification
8+
4. Tile size selection (small M vs large M)
9+
5. Init/run caching behavior
910
"""
1011

1112
import pytest
@@ -92,6 +93,11 @@ def test_scale_to_blocked_batched_exists(self):
9293
assert hasattr(lib, "cscale_to_blocked_batched"), \
9394
"Batched scale swizzle kernel not found"
9495

96+
def test_weighted_gather_exists(self):
97+
from bitsandbytes.cextension import lib
98+
assert hasattr(lib, "cmoe_weighted_gather_bf16"), \
99+
"Weighted gather kernel not found — moe_scatter_gather.cu not updated"
100+
95101
def test_fused_quantize_exists(self):
96102
from bitsandbytes.cextension import lib
97103
assert hasattr(lib, "cfused_quantize_nvfp4_quest"), \
@@ -331,40 +337,28 @@ def test_pipeline_nan_diagnosis(self, small_moe_config):
331337

332338
# Step 2: quantize
333339
global_scale = (1.0 / act_scale).to(torch.float32)
334-
print(f" Step 2 (global_scale = 1/abs_max): {global_scale.item():.8f}")
335340
packed_all, scales_all = quantize_nvfp4_raw(x_2d, global_scale)
336-
print(f" Step 2 (quantize): packed={packed_all.shape}, scales={scales_all.shape}, "
337-
f"packed_nonzero={packed_all.count_nonzero().item()}/{packed_all.numel()}")
341+
print(f" Step 2 (quantize): packed={packed_all.shape}, scales={scales_all.shape}")
338342

339343
# Step 3: scatter
340344
packed_batched = moe_scatter_nvfp4(packed_all, expert_offsets_i32, max_M, K, num_experts)
341-
print(f" Step 3 (scatter): shape={packed_batched.shape}, "
342-
f"nonzero={packed_batched.count_nonzero().item()}/{packed_batched.numel()}")
345+
print(f" Step 3 (scatter): shape={packed_batched.shape}")
343346

344347
# Step 4: swizzle scales
345348
sfa_batched = scale_to_blocked_batched(scales_all, expert_offsets_i32, max_M, K, num_experts)
346-
print(f" Step 4 (swizzle): shape={sfa_batched.shape}, "
347-
f"nonzero={sfa_batched.count_nonzero().item()}/{sfa_batched.numel()}")
349+
print(f" Step 4 (swizzle): shape={sfa_batched.shape}")
348350

349-
# Step 5: GEMM
351+
# Step 5: GEMM (uses init/run split internally)
350352
alpha_dev = (act_scale * layer.weight_tensor_scale).to(torch.float32)
351-
print(f" Step 5 (alpha): {alpha_dev.item():.6f}, "
352-
f"weight_tensor_scale={layer.weight_tensor_scale:.6f}")
353-
print(f" Step 5 (weight_packed): shape={layer.weight_packed.shape}, "
354-
f"nonzero={layer.weight_packed.count_nonzero().item()}/{layer.weight_packed.numel()}")
355-
print(f" Step 5 (weight_scales_batched): shape={layer.weight_scales_batched.shape}, "
356-
f"nonzero={layer.weight_scales_batched.count_nonzero().item()}/{layer.weight_scales_batched.numel()}")
357-
358353
D = gemm_nvfp4_moe(
359354
packed_batched, sfa_batched, alpha_dev,
360355
layer.weight_packed, layer.weight_scales_batched,
361356
max_M, N, K, num_experts,
362357
)
363358
torch.cuda.synchronize()
364359
nan_D = torch.isnan(D).sum().item()
365-
inf_D = torch.isinf(D).sum().item()
366360
print(f" Step 5 (GEMM out): shape={D.shape}, nan={nan_D}/{D.numel()}, "
367-
f"inf={inf_D}, abs_max={D[~torch.isnan(D)].abs().max().item() if nan_D < D.numel() else 'all_nan'}")
361+
f"abs_max={D[~torch.isnan(D)].abs().max().item() if nan_D < D.numel() else 'all_nan'}")
368362

369363
# Step 6: gather
370364
D_flat = D.view(-1).contiguous()
@@ -373,22 +367,9 @@ def test_pipeline_nan_diagnosis(self, small_moe_config):
373367
nan_out = torch.isnan(out).sum().item()
374368
print(f" Step 6 (gather): shape={out.shape}, nan={nan_out}/{out.numel()}")
375369

376-
# Try with scalar alpha=1.0 to isolate alpha_ptr issue
377-
alpha_one = torch.tensor([1.0], dtype=torch.float32, device="cuda")
378-
D2 = gemm_nvfp4_moe(
379-
packed_batched, sfa_batched, alpha_one,
380-
layer.weight_packed, layer.weight_scales_batched,
381-
max_M, N, K, num_experts,
382-
)
383-
torch.cuda.synchronize()
384-
nan_D2 = torch.isnan(D2).sum().item()
385-
print(f" GEMM with alpha=1.0: nan={nan_D2}/{D2.numel()}, "
386-
f"abs_max={D2[~torch.isnan(D2)].abs().max().item() if nan_D2 < D2.numel() else 'all_nan'}")
387-
388370
assert nan_D == 0, \
389371
f"GEMM output has {nan_D}/{D.numel()} NaN elements"
390372

391-
# Verify output is non-zero (weights should be non-zero after init)
392373
assert D.abs().max().item() > 0, \
393374
f"GEMM output is all zeros despite non-zero weights"
394375

@@ -509,17 +490,75 @@ def test_device_alpha_produces_output(self, small_moe_config):
509490
assert alpha_dev.dtype == torch.float32
510491

511492

512-
class TestKernelLaunchCount:
513-
"""Verify the pipeline uses the expected number of kernel launches."""
493+
class TestTileSelection:
494+
"""Test that the two tile sizes work correctly for different M values."""
514495

515-
def test_no_item_in_compute_path(self, small_moe_config):
516-
"""Verify no .item() calls happen in the compute pipeline.
496+
def test_small_m_uses_small_tile(self):
497+
"""M < 512 should trigger the small tile (128x128x256)."""
498+
K = 256
499+
N = 512
500+
num_experts = 4
501+
# 4 tokens per expert → max_M = 128 → small tile
502+
tpe = [4, 8, 2, 6]
503+
total_tokens = sum(tpe)
517504

518-
We test this by checking that the pipeline can run entirely
519-
within a CUDA stream without explicit synchronization.
520-
"""
521-
from bitsandbytes.nn.modules import LinearNVFP4MoE
505+
layer = _make_moe_layer(num_experts, K, N, bias=False)
506+
x = torch.randn(total_tokens, K, dtype=torch.bfloat16, device="cuda")
507+
expert_offsets = _make_expert_offsets(tpe)
508+
509+
out = layer(x, expert_offsets)
510+
assert out.shape == (total_tokens, N)
511+
assert not torch.isnan(out).any(), "Small tile output has NaN"
522512

513+
def test_large_m_uses_large_tile(self):
514+
"""M >= 512 should trigger the large tile (128x256x256)."""
515+
K = 256
516+
N = 512
517+
num_experts = 2
518+
# 512 tokens per expert → max_M = 512 → large tile
519+
tpe = [512, 256]
520+
total_tokens = sum(tpe)
521+
522+
layer = _make_moe_layer(num_experts, K, N, bias=False)
523+
x = torch.randn(total_tokens, K, dtype=torch.bfloat16, device="cuda")
524+
expert_offsets = _make_expert_offsets(tpe)
525+
526+
out = layer(x, expert_offsets)
527+
assert out.shape == (total_tokens, N)
528+
assert not torch.isnan(out).any(), "Large tile output has NaN"
529+
530+
531+
class TestInitRunCaching:
532+
"""Test that the init/run split caches correctly."""
533+
534+
def test_repeated_calls_same_dims(self, small_moe_config):
535+
"""Multiple calls with same dimensions should reuse cached init."""
536+
K = small_moe_config["input_features"]
537+
N = small_moe_config["output_features"]
538+
num_experts = small_moe_config["num_experts"]
539+
tpe = small_moe_config["tokens_per_expert"]
540+
total_tokens = sum(tpe)
541+
542+
layer = _make_moe_layer(num_experts, K, N, bias=False)
543+
expert_offsets = _make_expert_offsets(tpe)
544+
545+
results = []
546+
for _ in range(3):
547+
x = torch.randn(total_tokens, K, dtype=torch.bfloat16, device="cuda")
548+
out = layer(x, expert_offsets)
549+
results.append(out.clone())
550+
551+
# All outputs should be valid (no NaN from stale pointers)
552+
for i, r in enumerate(results):
553+
assert not torch.isnan(r).any(), f"Call {i} produced NaN"
554+
assert r.abs().max().item() > 0, f"Call {i} produced all zeros"
555+
556+
557+
class TestStreamCompatibility:
558+
"""Verify the pipeline works on non-default streams."""
559+
560+
def test_no_item_in_compute_path(self, small_moe_config):
561+
"""Verify the pipeline can run entirely within a CUDA stream."""
523562
K = small_moe_config["input_features"]
524563
N = small_moe_config["output_features"]
525564
num_experts = small_moe_config["num_experts"]

0 commit comments

Comments (0)