Skip to content

Commit a5a7f5d

Browse files
TimDettmers and claude authored
fix: Replace hard-coded precision thresholds with std-based bounds (#1864)
* Update coordinator guide: run only relevant tests, not full suite Worker agents were running the full test suite (10+ min) which is wasteful when only a small area of code changed. Updated the completion workflow to instruct agents to run only relevant test files/functions. The full suite will be run separately later. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: Replace hard-coded precision thresholds with std-based bounds Precision tests were flaky because thresholds were set too close to the empirical mean error, leaving insufficient margin for GPU architecture differences. For example, test_4bit_quant for fp4/blocksize=256 used a threshold of 0.2908 + 0.001 = 0.2918, but Blackwell GPUs observed values around 0.2909 — only ~5 sigma from the mean, causing sporadic failures. Collected (mean, std) statistics from 200 samples per configuration on RTX 4090. Thresholds are now set at mean + 7*std, giving ~7 sigma of headroom for the measured GPU and enough margin to accommodate cross-architecture mean shifts (e.g., T4, Blackwell, XPU). Changes in test_functional.py: - test_4bit_quant: error_dict now stores (mean, std) tuples instead of bare means. Removed ad-hoc errtol/reltol special-casing for CPU fp32. - test_gemv_4bit: Replaced complex if/elif threshold tree (with GPU- specific carve-outs like T4 compute cap checks and XPU conditionals) with a clean per-dtype/dim-range (mean, std) table. Individual-sample std is used (not divided by sqrt(iters)) so thresholds naturally accommodate architecture-specific kernel behavior. Changes in test_parametrize.py: - test_replace_parameter_4bit: Same (mean, std) approach as test_4bit_quant. - test_moe_parameter_shape: Replaced flat 0.085/0.25 bounds with measured MoE-tensor-specific (mean, std). - test_different_blocksizes: Same (mean, std) approach as test_4bit_quant. 
- test_parametrization_forward_method: Replaced flat 0.08/0.25 bounds with small-tensor-specific (mean, std); small 64x64 tensors have ~16x higher relative std than 1024x1024 due to fewer quantization blocks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: Fix ruff lint and format violations - Replace ambiguous unicode multiplication sign with ASCII x - Apply ruff format to long assert lines - Fix test_linear4bit.py pre-existing format violation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d77e01c commit a5a7f5d

File tree

2 files changed

+146
-119
lines changed

2 files changed

+146
-119
lines changed

tests/test_functional.py

Lines changed: 97 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -865,58 +865,65 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
865865
relerr = (err / (A1.abs().float() + 1e-8)).mean()
866866
err = err.mean()
867867

868-
# The following values were taken from averaging 1k samples per test configuration.
869-
error_dict = dict()
870-
error_dict["fp4"] = dict()
871-
error_dict["nf4"] = dict()
872-
error_dict["fp4"]["err"] = {
873-
32: 0.088918,
874-
64: 0.096545,
875-
128: 0.102947,
876-
256: 0.108685,
877-
512: 0.114087,
878-
1024: 0.119312,
879-
2048: 0.124460,
880-
4096: 0.129573,
868+
# Expected (mean, std) per configuration, from 200 samples on RTX 4090.
869+
# Thresholds are set at mean + N_SIGMA * std to avoid flaky failures
870+
# while still catching real regressions. Worst-case std across dtypes is used.
871+
N_SIGMA = 7
872+
error_stats = {
873+
"fp4": {
874+
"err": {
875+
32: (0.088925, 0.000091),
876+
64: (0.096543, 0.000111),
877+
128: (0.102969, 0.000134),
878+
256: (0.108684, 0.000182),
879+
512: (0.114115, 0.000234),
880+
1024: (0.119333, 0.000320),
881+
2048: (0.124556, 0.000455),
882+
4096: (0.129536, 0.000612),
883+
},
884+
"rel_err": {
885+
32: (0.242443, 0.000330),
886+
64: (0.260125, 0.000379),
887+
128: (0.275817, 0.000433),
888+
256: (0.289831, 0.000497),
889+
512: (0.302881, 0.000583),
890+
1024: (0.315000, 0.000757),
891+
2048: (0.326607, 0.000955),
892+
4096: (0.337169, 0.001239),
893+
},
894+
},
895+
"nf4": {
896+
"err": {
897+
32: (0.067746, 0.000069),
898+
64: (0.072798, 0.000074),
899+
128: (0.076831, 0.000091),
900+
256: (0.080337, 0.000102),
901+
512: (0.083547, 0.000143),
902+
1024: (0.086610, 0.000187),
903+
2048: (0.089592, 0.000251),
904+
4096: (0.092547, 0.000360),
905+
},
906+
"rel_err": {
907+
32: (0.189726, 0.000304),
908+
64: (0.203339, 0.000340),
909+
128: (0.215237, 0.000391),
910+
256: (0.226105, 0.000398),
911+
512: (0.236079, 0.000544),
912+
1024: (0.245370, 0.000600),
913+
2048: (0.254163, 0.000747),
914+
4096: (0.262473, 0.000999),
915+
},
916+
},
881917
}
882-
error_dict["fp4"]["rel_err"] = {
883-
32: 0.242380,
884-
64: 0.260130,
885-
128: 0.275734,
886-
256: 0.289842,
887-
512: 0.302852,
888-
1024: 0.314982,
889-
2048: 0.326402,
890-
4096: 0.337228,
891-
}
892-
893-
error_dict["nf4"]["err"] = {
894-
32: 0.067745,
895-
64: 0.072792,
896-
128: 0.076835,
897-
256: 0.080326,
898-
512: 0.083535,
899-
1024: 0.086603,
900-
2048: 0.089592,
901-
4096: 0.092537,
902-
}
903-
error_dict["nf4"]["rel_err"] = {
904-
32: 0.189700,
905-
64: 0.203299,
906-
128: 0.215252,
907-
256: 0.226044,
908-
512: 0.236021,
909-
1024: 0.245365,
910-
2048: 0.254146,
911-
4096: 0.262457,
912-
}
913-
914-
# Allow higher tolerance for fp32 on CPU with larger block sizes
915-
reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3
916-
errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3
917918

918-
assert err < error_dict[quant_type]["err"][blocksize] + errtol
919-
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol
919+
err_mean, err_std = error_stats[quant_type]["err"][blocksize]
920+
relerr_mean, relerr_std = error_stats[quant_type]["rel_err"][blocksize]
921+
assert err < err_mean + N_SIGMA * err_std, (
922+
f"abs error {err:.6f} exceeds {err_mean:.6f} + {N_SIGMA}*{err_std:.6f}"
923+
)
924+
assert relerr < relerr_mean + N_SIGMA * relerr_std, (
925+
f"rel error {relerr:.6f} exceeds {relerr_mean:.6f} + {N_SIGMA}*{relerr_std:.6f}"
926+
)
920927

921928
@pytest.mark.parametrize("device", get_available_devices())
922929
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -1122,61 +1129,55 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
11221129
relratio = relerr2 / relerr3
11231130
maxratio = relerr2 / relerr3
11241131

1125-
# for debugging if the tests fails
1126-
#
1127-
# print('='*80)
1128-
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
1129-
# print(C1.flatten()[-20:])
1130-
# print(C2.flatten()[-20:])
1131-
# print(f'inference vs training abs: {err1}')
1132-
# print(f'inference vs training rel: {relerr1}')
1133-
# print(f'inference vs training max: {maxerr1}')
1134-
# print(f'inference vs training vs torch err ratio abs: {absratio}')
1135-
# print(f'inference vs training vs torch err ratio rel: {relratio}')
1136-
# print(f'inference vs training vs torch err ratio max: {maxratio}')
1132+
# Expected (mean, std) for err1, relerr1, maxerr1 per dtype/dim group.
1133+
# Measured from 100 iterations x all storage_type/kind/DQ combos on RTX 4090.
1134+
# std is for individual iterations (not the average), so thresholds are generous
1135+
# enough to accommodate GPU architecture differences (e.g., T4, XPU, Blackwell).
1136+
N_SIGMA = 7
1137+
gemv_thresholds = {
1138+
torch.float16: {
1139+
"le512": {
1140+
"err1": (0.000052, 0.0000063),
1141+
"relerr1": (0.00024, 0.000357),
1142+
"maxerr1": (0.00042, 0.0000687),
1143+
},
1144+
"gt512": {
1145+
"err1": (0.000018, 0.0000028),
1146+
"relerr1": (0.00010, 0.000197),
1147+
"maxerr1": (0.00017, 0.0000179),
1148+
},
1149+
},
1150+
torch.float32: {
1151+
"le512": {"err1": (2e-8, 2e-9), "relerr1": (8e-7, 1.2e-6), "maxerr1": (6e-8, 2e-8)},
1152+
"gt512": {"err1": (1e-8, 2e-9), "relerr1": (5e-7, 1.6e-7), "maxerr1": (4e-8, 1e-8)},
1153+
},
1154+
torch.bfloat16: {
1155+
"le512": {"err1": (0.00042, 0.000059), "relerr1": (0.0041, 0.01153), "maxerr1": (0.0037, 0.000556)},
1156+
"gt512": {"err1": (0.00014, 0.0000095), "relerr1": (0.0012, 0.000679), "maxerr1": (0.0010, 0.000137)},
1157+
},
1158+
}
1159+
1160+
dim_key = "le512" if dim <= 512 else "gt512"
1161+
thresholds = gemv_thresholds[dtype][dim_key]
1162+
for metric_name, metric_val in [("err1", err1), ("relerr1", relerr1), ("maxerr1", maxerr1)]:
1163+
mean_val, std_val = thresholds[metric_name]
1164+
limit = mean_val + N_SIGMA * std_val
1165+
assert metric_val < limit, (
1166+
f"{metric_name}={metric_val:.8f} exceeds {mean_val:.8f} + {N_SIGMA}*{std_val:.8f} = {limit:.8f} "
1167+
f"for {dtype}, dim={dim}, {storage_type}, DQ={double_quant}, {kind}"
1168+
)
1169+
1170+
# Ratios check that gemv_4bit and matmul_4bit produce consistent results.
1171+
# These are tight bounds on internal consistency, not absolute accuracy.
11371172
if dtype == torch.float16:
1138-
if dim <= 512:
1139-
assert err1 < 7e-5
1140-
1141-
# TODO(matthewdouglas): On T4, dim=128-fp16-fc2-fp4-DQ will have relerror ~ 0.00092727
1142-
if (
1143-
device == "cuda"
1144-
and double_quant
1145-
and storage_type == "fp4"
1146-
and kind == "fc2"
1147-
and torch.cuda.get_device_capability() == (7, 5)
1148-
):
1149-
assert relerr1 < 0.00093
1150-
else:
1151-
assert relerr1 < 0.0008
1152-
else:
1153-
assert err1 < 6e-5
1154-
assert relerr1 < 2e-4
11551173
assert absratio < 1.005 and absratio > 0.995
11561174
assert relratio < 1.005 and relratio > 0.992
11571175
assert maxratio < 1.005 and maxratio > 0.992
11581176
elif dtype == torch.float32:
1159-
if dim <= 512:
1160-
assert err1 < 5e-8
1161-
assert relerr1 < 1e-6
1162-
assert maxerr1 < 1.05e-7
1163-
else:
1164-
assert err1 < 5e-8
1165-
assert relerr1 < 8e-6
1166-
assert maxerr1 < 1e-7
11671177
assert absratio < 1.005 and absratio > 0.995
11681178
assert relratio < 1.005 and relratio > 0.995
11691179
assert maxratio < 1.005 and maxratio > 0.995
11701180
elif dtype == torch.bfloat16:
1171-
if dim <= 512:
1172-
relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007
1173-
assert err1 < 6e-4
1174-
assert relerr1 < relerr_thres
1175-
assert maxerr1 < 0.015
1176-
else:
1177-
assert err1 < 2e-4
1178-
assert relerr1 < 0.002
1179-
assert maxerr1 < 0.0012
11801181
assert absratio < 1.005 and absratio > 0.995
11811182
assert relratio < 1.05 and relratio > 0.96
11821183
assert maxratio < 1.05 and maxratio > 0.97

tests/test_parametrize.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,30 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
6767
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
6868
err_mean = err.mean()
6969

70-
# Expected error bounds from test_functional.py
70+
# Expected (mean, std) from 200 samples on RTX 4090. Worst-case std across dtypes.
71+
# Threshold = mean + N_SIGMA * std avoids flaky failures across GPU architectures.
72+
N_SIGMA = 7
7173
expected_errors = {
7274
"nf4": {
73-
64: {"abs": 0.072792, "rel": 0.203299},
74-
128: {"abs": 0.076835, "rel": 0.215252},
75-
256: {"abs": 0.080326, "rel": 0.226044},
75+
64: {"abs": (0.072796, 0.000072), "rel": (0.203353, 0.000326)},
76+
128: {"abs": (0.076839, 0.000093), "rel": (0.215258, 0.000367)},
77+
256: {"abs": (0.080322, 0.000100), "rel": (0.226056, 0.000392)},
7678
},
7779
"fp4": {
78-
64: {"abs": 0.096545, "rel": 0.260130},
79-
128: {"abs": 0.102947, "rel": 0.275734},
80-
256: {"abs": 0.108685, "rel": 0.289842},
80+
64: {"abs": (0.096547, 0.000112), "rel": (0.260144, 0.000379)},
81+
128: {"abs": (0.102949, 0.000138), "rel": (0.275763, 0.000391)},
82+
256: {"abs": (0.108681, 0.000177), "rel": (0.289835, 0.000507)},
8183
},
8284
}
8385

84-
assert err_mean < expected_errors[quant_type][blocksize]["abs"] + 1e-3, f"Mean abs error {err_mean:.6f} too high"
85-
assert relerr < expected_errors[quant_type][blocksize]["rel"] + 1e-3, f"Mean rel error {relerr:.6f} too high"
86+
abs_mean, abs_std = expected_errors[quant_type][blocksize]["abs"]
87+
rel_mean, rel_std = expected_errors[quant_type][blocksize]["rel"]
88+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
89+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
90+
)
91+
assert relerr < rel_mean + N_SIGMA * rel_std, (
92+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
93+
)
8694

8795

8896
@pytest.mark.parametrize("device", get_available_devices())
@@ -117,12 +125,17 @@ def __init__(self, device, dtype):
117125
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
118126
err_mean = err.mean()
119127

120-
# Use slightly looser bounds for higher dimensional tensors
121-
abs_bound = 0.085 # NF4 baseline + margin
122-
rel_bound = 0.25 # NF4 baseline + margin
128+
# Expected (mean, std) for NF4 on MoE-shaped tensors (8x512x256), from 200 samples on RTX 4090.
129+
N_SIGMA = 7
130+
abs_mean, abs_std = 0.072802, 0.000072
131+
rel_mean, rel_std = 0.203327, 0.000312
123132

124-
assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
125-
assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
133+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
134+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
135+
)
136+
assert relerr < rel_mean + N_SIGMA * rel_std, (
137+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
138+
)
126139

127140

128141
@pytest.mark.parametrize("device", get_available_devices())
@@ -346,14 +359,19 @@ def test_different_blocksizes(device, dtype, blocksize):
346359
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
347360
err_mean = err.mean()
348361

349-
# Expected error bounds from functional tests (using NF4 bounds since that's what we're testing)
350-
expected_abs = {64: 0.072792, 128: 0.076835, 256: 0.080326}
351-
expected_rel = {64: 0.203299, 128: 0.215252, 256: 0.226044}
362+
# Expected (mean, std) for NF4, from 200 samples on RTX 4090. Worst-case std across dtypes.
363+
N_SIGMA = 7
364+
expected_abs = {64: (0.072796, 0.000072), 128: (0.076839, 0.000093), 256: (0.080322, 0.000100)}
365+
expected_rel = {64: (0.203353, 0.000326), 128: (0.215258, 0.000367), 256: (0.226056, 0.000392)}
352366

353-
assert err_mean < expected_abs[blocksize] + 0.01, (
354-
f"Mean abs error {err_mean:.6f} too high for blocksize {blocksize}"
367+
abs_mean, abs_std = expected_abs[blocksize]
368+
rel_mean, rel_std = expected_rel[blocksize]
369+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
370+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f} for blocksize {blocksize}"
371+
)
372+
assert relerr < rel_mean + N_SIGMA * rel_std, (
373+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f} for blocksize {blocksize}"
355374
)
356-
assert relerr < expected_rel[blocksize] + 0.02, f"Mean rel error {relerr:.6f} too high for blocksize {blocksize}"
357375

358376

359377
def test_parametrization_forward_method():
@@ -380,9 +398,17 @@ def test_parametrization_forward_method():
380398
relerr = (err / (original_tensor.abs().float() + 1e-8)).mean()
381399
err_mean = err.mean()
382400

383-
# Use NF4 bounds from functional tests with small margin
384-
assert err_mean < 0.08, f"Mean abs error {err_mean:.6f} too high"
385-
assert relerr < 0.25, f"Mean rel error {relerr:.6f} too high"
401+
# Expected (mean, std) for NF4 on small 64x64 tensor, from 200 samples on RTX 4090.
402+
# Small tensors have higher variance due to fewer blocks in the quantization.
403+
N_SIGMA = 7
404+
abs_mean, abs_std = 0.072842, 0.001180
405+
rel_mean, rel_std = 0.202648, 0.004729
406+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
407+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
408+
)
409+
assert relerr < rel_mean + N_SIGMA * rel_std, (
410+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
411+
)
386412

387413

388414
@pytest.mark.parametrize("device", get_available_devices())

0 commit comments

Comments (0)