Skip to content

Commit 00fdc39

Browse files
TimDettmers and claude committed
Fix test ordering: check NaN explicitly, access state after forward
- test_pipeline_deterministic: check for NaN before comparing outputs
- test_pipeline_larger_config: access weight_scales_batched after forward

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 52076fc commit 00fdc39

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

tests/test_moe_sm100_pipeline.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def test_pipeline_deterministic(self, small_moe_config):
296296
297297
Note: CUTLASS SM_100 block-scaled GEMM may have non-deterministic
298298
accumulation order across tiles, so we use approximate comparison.
299+
NaN in output indicates a GEMM state issue (tracked separately).
299300
"""
300301
from bitsandbytes.nn.modules import LinearNVFP4MoE
301302

@@ -312,7 +313,13 @@ def test_pipeline_deterministic(self, small_moe_config):
312313
expert_offsets = _make_expert_offsets(tpe)
313314

314315
out1 = layer(x, expert_offsets)
316+
has_nan1 = torch.isnan(out1).any().item()
317+
assert not has_nan1, f"First call produced NaN (abs_max={out1.abs().max().item()})"
318+
315319
out2 = layer(x, expert_offsets)
320+
has_nan2 = torch.isnan(out2).any().item()
321+
assert not has_nan2, \
322+
f"Second call produced NaN (first call abs_max={out1.abs().max().item()})"
316323

317324
# Allow small numerical differences from non-deterministic accumulation
318325
if not torch.equal(out1, out2):
@@ -348,15 +355,18 @@ def test_pipeline_larger_config(self, moe_config):
348355
layer = LinearNVFP4MoE(num_experts, K, N, bias=False)
349356
layer = layer.cuda()
350357

351-
# Diagnostic: check weight_scales_batched size
352-
actual_sfb = layer.weight_scales_batched.numel()
353-
print(f" weight_scales_batched size: {actual_sfb} bytes, expected batched: {sfb_batched}")
354-
355358
x = torch.randn(total_tokens, K, dtype=torch.bfloat16, device="cuda")
356359
expert_offsets = _make_expert_offsets(tpe)
357360

358361
out = layer(x, expert_offsets)
359362

363+
# Diagnostic: check weight_scales_batched size (after forward triggers quantization)
364+
if layer.weight_scales_batched is not None:
365+
actual_sfb = layer.weight_scales_batched.numel()
366+
print(f" weight_scales_batched size: {actual_sfb} bytes, expected batched: {sfb_batched}")
367+
else:
368+
print(" WARNING: weight_scales_batched is None after forward")
369+
360370
assert out.shape == (total_tokens, N)
361371
assert out.dtype == torch.bfloat16
362372

0 commit comments

Comments (0)