Skip to content

Commit 8bd5e49

Browse files
TimDettmers and claude committed
fix: Replace hard-coded precision thresholds with std-based bounds
Precision tests were flaky because thresholds were set too close to the empirical mean error, leaving insufficient margin for GPU architecture differences. For example, test_4bit_quant for fp4/blocksize=256 used a threshold of 0.2908 + 0.001 = 0.2918, but Blackwell GPUs observed values around 0.2909 — only ~5 sigma from the mean, causing sporadic failures.

Collected (mean, std) statistics from 200 samples per configuration on RTX 4090. Thresholds are now set at mean + 7*std, giving ~7 sigma of headroom for the measured GPU and enough margin to accommodate cross-architecture mean shifts (e.g., T4, Blackwell, XPU).

Changes in test_functional.py:
- test_4bit_quant: error_dict now stores (mean, std) tuples instead of bare means. Removed ad-hoc errtol/reltol special-casing for CPU fp32.
- test_gemv_4bit: Replaced complex if/elif threshold tree (with GPU-specific carve-outs like T4 compute cap checks and XPU conditionals) with a clean per-dtype/dim-range (mean, std) table. Individual-sample std is used (not divided by sqrt(iters)) so thresholds naturally accommodate architecture-specific kernel behavior.

Changes in test_parametrize.py:
- test_replace_parameter_4bit: Same (mean, std) approach as test_4bit_quant.
- test_moe_parameter_shape: Replaced flat 0.085/0.25 bounds with measured MoE-tensor-specific (mean, std).
- test_different_blocksizes: Same (mean, std) approach as test_4bit_quant.
- test_parametrization_forward_method: Replaced flat 0.08/0.25 bounds with small-tensor-specific (mean, std); small 64x64 tensors have ~16x higher relative std than 1024x1024 due to fewer quantization blocks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ed47966 commit 8bd5e49

File tree

2 files changed

+122
-119
lines changed

2 files changed

+122
-119
lines changed

tests/test_functional.py

Lines changed: 85 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,58 +1117,61 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11171117
relerr = (err / (A1.abs().float() + 1e-8)).mean()
11181118
err = err.mean()
11191119

1120-
# The following values were taken from averaging 1k samples per test configuration.
1121-
error_dict = dict()
1122-
error_dict["fp4"] = dict()
1123-
error_dict["nf4"] = dict()
1124-
error_dict["fp4"]["err"] = {
1125-
32: 0.088918,
1126-
64: 0.096545,
1127-
128: 0.102947,
1128-
256: 0.108685,
1129-
512: 0.114087,
1130-
1024: 0.119312,
1131-
2048: 0.124460,
1132-
4096: 0.129573,
1120+
# Expected (mean, std) per configuration, from 200 samples on RTX 4090.
1121+
# Thresholds are set at mean + N_SIGMA * std to avoid flaky failures
1122+
# while still catching real regressions. Worst-case std across dtypes is used.
1123+
N_SIGMA = 7
1124+
error_stats = {
1125+
"fp4": {
1126+
"err": {
1127+
32: (0.088925, 0.000091),
1128+
64: (0.096543, 0.000111),
1129+
128: (0.102969, 0.000134),
1130+
256: (0.108684, 0.000182),
1131+
512: (0.114115, 0.000234),
1132+
1024: (0.119333, 0.000320),
1133+
2048: (0.124556, 0.000455),
1134+
4096: (0.129536, 0.000612),
1135+
},
1136+
"rel_err": {
1137+
32: (0.242443, 0.000330),
1138+
64: (0.260125, 0.000379),
1139+
128: (0.275817, 0.000433),
1140+
256: (0.289831, 0.000497),
1141+
512: (0.302881, 0.000583),
1142+
1024: (0.315000, 0.000757),
1143+
2048: (0.326607, 0.000955),
1144+
4096: (0.337169, 0.001239),
1145+
},
1146+
},
1147+
"nf4": {
1148+
"err": {
1149+
32: (0.067746, 0.000069),
1150+
64: (0.072798, 0.000074),
1151+
128: (0.076831, 0.000091),
1152+
256: (0.080337, 0.000102),
1153+
512: (0.083547, 0.000143),
1154+
1024: (0.086610, 0.000187),
1155+
2048: (0.089592, 0.000251),
1156+
4096: (0.092547, 0.000360),
1157+
},
1158+
"rel_err": {
1159+
32: (0.189726, 0.000304),
1160+
64: (0.203339, 0.000340),
1161+
128: (0.215237, 0.000391),
1162+
256: (0.226105, 0.000398),
1163+
512: (0.236079, 0.000544),
1164+
1024: (0.245370, 0.000600),
1165+
2048: (0.254163, 0.000747),
1166+
4096: (0.262473, 0.000999),
1167+
},
1168+
},
11331169
}
1134-
error_dict["fp4"]["rel_err"] = {
1135-
32: 0.242380,
1136-
64: 0.260130,
1137-
128: 0.275734,
1138-
256: 0.289842,
1139-
512: 0.302852,
1140-
1024: 0.314982,
1141-
2048: 0.326402,
1142-
4096: 0.337228,
1143-
}
1144-
1145-
error_dict["nf4"]["err"] = {
1146-
32: 0.067745,
1147-
64: 0.072792,
1148-
128: 0.076835,
1149-
256: 0.080326,
1150-
512: 0.083535,
1151-
1024: 0.086603,
1152-
2048: 0.089592,
1153-
4096: 0.092537,
1154-
}
1155-
error_dict["nf4"]["rel_err"] = {
1156-
32: 0.189700,
1157-
64: 0.203299,
1158-
128: 0.215252,
1159-
256: 0.226044,
1160-
512: 0.236021,
1161-
1024: 0.245365,
1162-
2048: 0.254146,
1163-
4096: 0.262457,
1164-
}
1165-
1166-
# Allow higher tolerance for fp32 on CPU with larger block sizes
1167-
reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3
1168-
errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3
11691170

1170-
assert err < error_dict[quant_type]["err"][blocksize] + errtol
1171-
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol
1171+
err_mean, err_std = error_stats[quant_type]["err"][blocksize]
1172+
relerr_mean, relerr_std = error_stats[quant_type]["rel_err"][blocksize]
1173+
assert err < err_mean + N_SIGMA * err_std, f"abs error {err:.6f} exceeds {err_mean:.6f} + {N_SIGMA}*{err_std:.6f}"
1174+
assert relerr < relerr_mean + N_SIGMA * relerr_std, f"rel error {relerr:.6f} exceeds {relerr_mean:.6f} + {N_SIGMA}*{relerr_std:.6f}"
11721175

11731176
@pytest.mark.parametrize("device", get_available_devices())
11741177
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -1374,61 +1377,47 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
13741377
relratio = relerr2 / relerr3
13751378
maxratio = relerr2 / relerr3
13761379

1377-
# for debugging if the tests fails
1378-
#
1379-
# print('='*80)
1380-
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1381-
# print(C1.flatten()[-20:])
1382-
# print(C2.flatten()[-20:])
1383-
# print(f'inference vs training abs: {err1}')
1384-
# print(f'inference vs training rel: {relerr1}')
1385-
# print(f'inference vs training max: {maxerr1}')
1386-
# print(f'inference vs training vs torch err ratio abs: {absratio}')
1387-
# print(f'inference vs training vs torch err ratio rel: {relratio}')
1388-
# print(f'inference vs training vs torch err ratio max: {maxratio}')
1380+
# Expected (mean, std) for err1, relerr1, maxerr1 per dtype/dim group.
1381+
# Measured from 100 iterations × all storage_type/kind/DQ combos on RTX 4090.
1382+
# std is for individual iterations (not the average), so thresholds are generous
1383+
# enough to accommodate GPU architecture differences (e.g., T4, XPU, Blackwell).
1384+
N_SIGMA = 7
1385+
gemv_thresholds = {
1386+
torch.float16: {
1387+
"le512": {"err1": (0.000052, 0.0000063), "relerr1": (0.00024, 0.000357), "maxerr1": (0.00042, 0.0000687)},
1388+
"gt512": {"err1": (0.000018, 0.0000028), "relerr1": (0.00010, 0.000197), "maxerr1": (0.00017, 0.0000179)},
1389+
},
1390+
torch.float32: {
1391+
"le512": {"err1": (2e-8, 2e-9), "relerr1": (8e-7, 1.2e-6), "maxerr1": (6e-8, 2e-8)},
1392+
"gt512": {"err1": (1e-8, 2e-9), "relerr1": (5e-7, 1.6e-7), "maxerr1": (4e-8, 1e-8)},
1393+
},
1394+
torch.bfloat16: {
1395+
"le512": {"err1": (0.00042, 0.000059), "relerr1": (0.0041, 0.01153), "maxerr1": (0.0037, 0.000556)},
1396+
"gt512": {"err1": (0.00014, 0.0000095), "relerr1": (0.0012, 0.000679), "maxerr1": (0.0010, 0.000137)},
1397+
},
1398+
}
1399+
1400+
dim_key = "le512" if dim <= 512 else "gt512"
1401+
thresholds = gemv_thresholds[dtype][dim_key]
1402+
for metric_name, metric_val in [("err1", err1), ("relerr1", relerr1), ("maxerr1", maxerr1)]:
1403+
mean_val, std_val = thresholds[metric_name]
1404+
limit = mean_val + N_SIGMA * std_val
1405+
assert metric_val < limit, (
1406+
f"{metric_name}={metric_val:.8f} exceeds {mean_val:.8f} + {N_SIGMA}*{std_val:.8f} = {limit:.8f} "
1407+
f"for {dtype}, dim={dim}, {storage_type}, DQ={double_quant}, {kind}"
1408+
)
1409+
1410+
# Ratios check that gemv_4bit and matmul_4bit produce consistent results.
1411+
# These are tight bounds on internal consistency, not absolute accuracy.
13891412
if dtype == torch.float16:
1390-
if dim <= 512:
1391-
assert err1 < 7e-5
1392-
1393-
# TODO(matthewdouglas): On T4, dim=128-fp16-fc2-fp4-DQ will have relerror ~ 0.00092727
1394-
if (
1395-
device == "cuda"
1396-
and double_quant
1397-
and storage_type == "fp4"
1398-
and kind == "fc2"
1399-
and torch.cuda.get_device_capability() == (7, 5)
1400-
):
1401-
assert relerr1 < 0.00093
1402-
else:
1403-
assert relerr1 < 0.0008
1404-
else:
1405-
assert err1 < 6e-5
1406-
assert relerr1 < 2e-4
14071413
assert absratio < 1.005 and absratio > 0.995
14081414
assert relratio < 1.005 and relratio > 0.992
14091415
assert maxratio < 1.005 and maxratio > 0.992
14101416
elif dtype == torch.float32:
1411-
if dim <= 512:
1412-
assert err1 < 5e-8
1413-
assert relerr1 < 1e-6
1414-
assert maxerr1 < 1.05e-7
1415-
else:
1416-
assert err1 < 5e-8
1417-
assert relerr1 < 8e-6
1418-
assert maxerr1 < 1e-7
14191417
assert absratio < 1.005 and absratio > 0.995
14201418
assert relratio < 1.005 and relratio > 0.995
14211419
assert maxratio < 1.005 and maxratio > 0.995
14221420
elif dtype == torch.bfloat16:
1423-
if dim <= 512:
1424-
relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007
1425-
assert err1 < 6e-4
1426-
assert relerr1 < relerr_thres
1427-
assert maxerr1 < 0.015
1428-
else:
1429-
assert err1 < 2e-4
1430-
assert relerr1 < 0.002
1431-
assert maxerr1 < 0.0012
14321421
assert absratio < 1.005 and absratio > 0.995
14331422
assert relratio < 1.05 and relratio > 0.96
14341423
assert maxratio < 1.05 and maxratio > 0.97

tests/test_parametrize.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,22 +70,26 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
7070
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
7171
err_mean = err.mean()
7272

73-
# Expected error bounds from test_functional.py
73+
# Expected (mean, std) from 200 samples on RTX 4090. Worst-case std across dtypes.
74+
# Threshold = mean + N_SIGMA * std avoids flaky failures across GPU architectures.
75+
N_SIGMA = 7
7476
expected_errors = {
7577
"nf4": {
76-
64: {"abs": 0.072792, "rel": 0.203299},
77-
128: {"abs": 0.076835, "rel": 0.215252},
78-
256: {"abs": 0.080326, "rel": 0.226044},
78+
64: {"abs": (0.072796, 0.000072), "rel": (0.203353, 0.000326)},
79+
128: {"abs": (0.076839, 0.000093), "rel": (0.215258, 0.000367)},
80+
256: {"abs": (0.080322, 0.000100), "rel": (0.226056, 0.000392)},
7981
},
8082
"fp4": {
81-
64: {"abs": 0.096545, "rel": 0.260130},
82-
128: {"abs": 0.102947, "rel": 0.275734},
83-
256: {"abs": 0.108685, "rel": 0.289842},
83+
64: {"abs": (0.096547, 0.000112), "rel": (0.260144, 0.000379)},
84+
128: {"abs": (0.102949, 0.000138), "rel": (0.275763, 0.000391)},
85+
256: {"abs": (0.108681, 0.000177), "rel": (0.289835, 0.000507)},
8486
},
8587
}
8688

87-
assert err_mean < expected_errors[quant_type][blocksize]["abs"] + 1e-3, f"Mean abs error {err_mean:.6f} too high"
88-
assert relerr < expected_errors[quant_type][blocksize]["rel"] + 1e-3, f"Mean rel error {relerr:.6f} too high"
89+
abs_mean, abs_std = expected_errors[quant_type][blocksize]["abs"]
90+
rel_mean, rel_std = expected_errors[quant_type][blocksize]["rel"]
91+
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
92+
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
8993

9094

9195
@pytest.mark.parametrize("device", get_available_devices())
@@ -120,12 +124,13 @@ def __init__(self, device, dtype):
120124
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
121125
err_mean = err.mean()
122126

123-
# Use slightly looser bounds for higher dimensional tensors
124-
abs_bound = 0.085 # NF4 baseline + margin
125-
rel_bound = 0.25 # NF4 baseline + margin
127+
# Expected (mean, std) for NF4 on MoE-shaped tensors (8x512x256), from 200 samples on RTX 4090.
128+
N_SIGMA = 7
129+
abs_mean, abs_std = 0.072802, 0.000072
130+
rel_mean, rel_std = 0.203327, 0.000312
126131

127-
assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
128-
assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
132+
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
133+
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
129134

130135

131136
@pytest.mark.parametrize("device", get_available_devices())
@@ -349,14 +354,19 @@ def test_different_blocksizes(device, dtype, blocksize):
349354
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
350355
err_mean = err.mean()
351356

352-
# Expected error bounds from functional tests (using NF4 bounds since that's what we're testing)
353-
expected_abs = {64: 0.072792, 128: 0.076835, 256: 0.080326}
354-
expected_rel = {64: 0.203299, 128: 0.215252, 256: 0.226044}
357+
# Expected (mean, std) for NF4, from 200 samples on RTX 4090. Worst-case std across dtypes.
358+
N_SIGMA = 7
359+
expected_abs = {64: (0.072796, 0.000072), 128: (0.076839, 0.000093), 256: (0.080322, 0.000100)}
360+
expected_rel = {64: (0.203353, 0.000326), 128: (0.215258, 0.000367), 256: (0.226056, 0.000392)}
355361

356-
assert err_mean < expected_abs[blocksize] + 0.01, (
357-
f"Mean abs error {err_mean:.6f} too high for blocksize {blocksize}"
362+
abs_mean, abs_std = expected_abs[blocksize]
363+
rel_mean, rel_std = expected_rel[blocksize]
364+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
365+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f} for blocksize {blocksize}"
366+
)
367+
assert relerr < rel_mean + N_SIGMA * rel_std, (
368+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f} for blocksize {blocksize}"
358369
)
359-
assert relerr < expected_rel[blocksize] + 0.02, f"Mean rel error {relerr:.6f} too high for blocksize {blocksize}"
360370

361371

362372
def test_parametrization_forward_method():
@@ -383,9 +393,13 @@ def test_parametrization_forward_method():
383393
relerr = (err / (original_tensor.abs().float() + 1e-8)).mean()
384394
err_mean = err.mean()
385395

386-
# Use NF4 bounds from functional tests with small margin
387-
assert err_mean < 0.08, f"Mean abs error {err_mean:.6f} too high"
388-
assert relerr < 0.25, f"Mean rel error {relerr:.6f} too high"
396+
# Expected (mean, std) for NF4 on small 64x64 tensor, from 200 samples on RTX 4090.
397+
# Small tensors have higher variance due to fewer blocks in the quantization.
398+
N_SIGMA = 7
399+
abs_mean, abs_std = 0.072842, 0.001180
400+
rel_mean, rel_std = 0.202648, 0.004729
401+
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
402+
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
389403

390404

391405
@pytest.mark.parametrize("device", get_available_devices())

0 commit comments

Comments (0)