Commit a0d6e9b

switch correctness checks to SNR-based assertion for cuda quant int4_matmul (#19300)
Replace torch.allclose(atol/rtol) with an SNR (signal-to-noise ratio)
assertion across all int4_matmul / int4_matvec / dequant-vs-fused tests.

Why:
- test_prefill_short was flaking on CI (A10G) with max_abs_err=1.0000.
  Root cause: bf16 GEMM with a K=2048 reduction produces output
  magnitudes up to ~200; at that scale the bf16 ULP gap is 0.5-1.0. The
  Triton fused kernel and cuBLAS reduce in different orders (and Triton
  autotune picks different tile configs on different hardware), so 1-ULP
  element-wise differences are unavoidable. atol/rtol false-fails on
  these outliers; SNR averages them out.
- atol/rtol thresholds also depend on size: a value tuned for K=2048 is
  too loose for K=64 and too tight for K=4096. SNR is size-invariant
  (||signal|| and ||noise|| both scale with sqrt(N) and sqrt(K),
  canceling in the ratio).

What:
- Add an _assert_snr(test_case, actual, expected, label) helper that
  asserts 20*log10(||expected|| / ||actual - expected||) >= 50 dB.
- Replace 4 call sites: TestInt4Matmul, TestInt4Matvec (x2),
  TestDequantThenMatmul.
- 50 dB ≈ 0.3% RMS error: well below observed clean noise (80-90 dB)
  and well above any real functional bug (<20 dB SNR for a wrong stride,
  flipped nibble, off-by-one group_idx, or missing mask).

Test plan:
python -m pytest backends/cuda/tests/test_int4_matmul.py -v -> 35/35 passed
1 parent ff25a2f commit a0d6e9b

1 file changed: backends/cuda/tests/test_int4_matmul.py
Lines changed: 39 additions & 25 deletions
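As a worked check of the dB figures quoted in the commit message (an illustration added here, not part of the commit): an SNR of x dB corresponds to an RMS error fraction of 10^(-x/20).

# Illustrative only (not from the commit): convert the quoted dB
# thresholds into RMS error fractions via 10**(-db / 20).
for db in (20.0, 50.0, 80.0):
    print(f"{db:.0f} dB -> RMS error {10 ** (-db / 20):.4%}")
# Output:
# 20 dB -> RMS error 10.0000%   (the <20 dB regime of a real bug)
# 50 dB -> RMS error 0.3162%    (the assertion threshold, ~0.3%)
# 80 dB -> RMS error 0.0100%    (the observed clean noise floor)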
@@ -19,7 +19,6 @@
 import unittest
 
 import torch
-
 from executorch.backends.cuda.triton.kernels.int4_matmul import (
     dequant_w4_to_bf16,
     int4_matmul,
@@ -28,6 +27,41 @@
 
 ATOL = 0.01
 DEVICE = "cuda"
+SNR_THRESHOLD_DB = 50.0
+
+
+def _assert_snr(test_case, actual, expected, label, threshold_db=SNR_THRESHOLD_DB):
+    """Assert signal-to-noise ratio (in dB) of `actual` vs `expected` >= threshold.
+
+    SNR = 20*log10(||expected||_2 / ||actual - expected||_2)
+
+    Why SNR rather than torch.allclose(atol/rtol):
+    * Size-invariant: ||signal|| and ||noise|| both scale with sqrt(N) and
+      with sqrt(K) (CLT + random-walk rounding), so the ratio is independent
+      of tensor size and reduction depth. The same threshold works for
+      K=64 and K=4096, M=1 and M=1024.
+    * Robust to bf16 ULP outliers: with K=2048 and output magnitudes ~200,
+      a single element can differ by ~1.0 just from differing reduction
+      orders (Triton fused vs cuBLAS). atol/rtol false-fails on these;
+      SNR averages them out.
+    * Sensitive to real bugs: wrong stride, flipped nibble, off-by-one
+      group_idx, or a missing mask all collapse SNR to <20 dB. The 50 dB
+      threshold (≈0.3% RMS error) sits comfortably between observed clean
+      noise floor (~80-90 dB) and any genuine functional break.
+    """
+    a = actual.float()
+    b = expected.float()
+    diff = a - b
+    signal = b.norm()
+    noise = diff.norm()
+    snr_db = (20.0 * torch.log10(signal / noise.clamp(min=1e-9))).item()
+    test_case.assertGreater(
+        snr_db,
+        threshold_db,
+        f"{label}: SNR={snr_db:.1f} dB (threshold {threshold_db:.1f} dB), "
+        f"max_abs_err={diff.abs().max().item():.4f}, "
+        f"signal_norm={signal.item():.2f}, noise_norm={noise.item():.4f}",
+    )
 
 
 def _quantize_simple(w_bf16, group_size):
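Aside, added for this writeup (not part of the diff): a minimal CPU-only sketch of the docstring's size-invariance claim. It simulates bf16 input rounding ahead of an fp32 matmul against an fp64 reference; the sizes and names here are illustrative, not taken from the test file.

import torch

def snr_db(actual, expected):
    # Same formula as the _assert_snr helper above.
    noise = (actual.float() - expected.float()).norm()
    return (20.0 * torch.log10(expected.float().norm() / noise.clamp(min=1e-9))).item()

torch.manual_seed(0)
for K in (64, 512, 4096):
    x = torch.randn(8, K)
    w = torch.randn(16, K)
    ref = (x.double() @ w.double().t()).float()              # high-precision reference
    lossy = x.bfloat16().float() @ w.bfloat16().float().t()  # bf16-rounded inputs, fp32 matmul
    print(f"K={K:5d}  max_abs_err={(lossy - ref).abs().max().item():.4f}  "
          f"SNR={snr_db(lossy, ref):.1f} dB")
# max_abs_err grows with K while SNR stays roughly constant across K,
# which is why one 50 dB threshold can fit every test size.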
@@ -118,12 +152,7 @@ def _run_matmul(self, M, N, K, group_size):
 
         self.assertEqual(out.shape, (M, N))
         self.assertEqual(out.dtype, torch.bfloat16)
-        self.assertTrue(
-            torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
-            f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
-            f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
-            f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
-        )
+        _assert_snr(self, out, ref, f"int4_matmul M={M} [{N}x{K}] gs={group_size}")
 
     # --- Decode (M=1) ---
     def test_decode_square(self):
@@ -189,13 +218,7 @@ def _run_matvec(self, N, K, group_size):
 
         self.assertEqual(out.shape, (1, N))
         self.assertEqual(out.dtype, torch.bfloat16)
-        # atol=1.0 for large accumulation across K, rtol=0.01 for relative
-        self.assertTrue(
-            torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
-            f"int4_matvec [{N}x{K}] gs={group_size}: "
-            f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
-            f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}",
-        )
+        _assert_snr(self, out, ref, f"int4_matvec [{N}x{K}] gs={group_size}")
 
     def test_qkv_proj(self):
         self._run_matvec(2048, 2048, 128)
@@ -226,10 +249,7 @@ def test_matches_int4_matmul(self):
         out_mv = int4_matvec(x, packed, scale, gs)
         out_mm = int4_matmul(x, packed, scale, gs)
 
-        self.assertTrue(
-            torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
-            f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
-        )
+        _assert_snr(self, out_mv, out_mm, "matvec vs matmul")
 
 
 class TestDequantThenMatmul(unittest.TestCase):
@@ -248,13 +268,7 @@ def _run(self, M, N, K, group_size):
         w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
         out_dequant = torch.nn.functional.linear(x, w_bf16)
 
-        self.assertTrue(
-            torch.allclose(
-                out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
-            ),
-            f"fused vs dequant M={M} [{N}x{K}]: "
-            f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
-        )
+        _assert_snr(self, out_fused, out_dequant, f"fused vs dequant M={M} [{N}x{K}]")
 
     def test_decode(self):
         self._run(1, 2048, 2048, 128)
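A second aside (again added here, not part of the diff): the helper's docstring claims any real functional bug collapses SNR below 20 dB. Rather than reproduce a flipped nibble against the real int4 packing, this hedged sketch simulates an off-by-one group index by dropping one 128-column group's contribution; with K=2048 that yields roughly 10*log10(2048/128) ≈ 12 dB.

import torch

def snr_db(actual, expected):
    noise = (actual.float() - expected.float()).norm()
    return (20.0 * torch.log10(expected.float().norm() / noise.clamp(min=1e-9))).item()

torch.manual_seed(0)
x = torch.randn(4, 2048)
w = torch.randn(1024, 2048)
ref = x @ w.t()

# Simulated bug: one 128-wide group never contributes, standing in for an
# off-by-one group_idx or a missing mask (the real nibble packing is not
# reproduced here).
w_bug = w.clone()
w_bug[:, :128] = 0.0
print(f"simulated bug:   SNR={snr_db(x @ w_bug.t(), ref):.1f} dB")  # ~12 dB, fails 50 dB

# Rounding-scale perturbation for contrast: 0.1% RMS noise.
noisy = ref + 1e-3 * ref.norm() / ref.numel() ** 0.5 * torch.randn_like(ref)
print(f"rounding-scale:  SNR={snr_db(noisy, ref):.1f} dB")          # ~60 dB, passes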
