Skip to content

Commit 6517a70

Browse files
TimDettmersclaude
andcommitted
style: Fix ruff lint and format violations
- Replace ambiguous unicode multiplication sign with ASCII x - Apply ruff format to long assert lines - Fix test_linear4bit.py pre-existing format violation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8bd5e49 commit 6517a70

File tree

3 files changed

+36
-14
lines changed

3 files changed

+36
-14
lines changed

tests/test_functional.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,8 +1170,12 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11701170

11711171
err_mean, err_std = error_stats[quant_type]["err"][blocksize]
11721172
relerr_mean, relerr_std = error_stats[quant_type]["rel_err"][blocksize]
1173-
assert err < err_mean + N_SIGMA * err_std, f"abs error {err:.6f} exceeds {err_mean:.6f} + {N_SIGMA}*{err_std:.6f}"
1174-
assert relerr < relerr_mean + N_SIGMA * relerr_std, f"rel error {relerr:.6f} exceeds {relerr_mean:.6f} + {N_SIGMA}*{relerr_std:.6f}"
1173+
assert err < err_mean + N_SIGMA * err_std, (
1174+
f"abs error {err:.6f} exceeds {err_mean:.6f} + {N_SIGMA}*{err_std:.6f}"
1175+
)
1176+
assert relerr < relerr_mean + N_SIGMA * relerr_std, (
1177+
f"rel error {relerr:.6f} exceeds {relerr_mean:.6f} + {N_SIGMA}*{relerr_std:.6f}"
1178+
)
11751179

11761180
@pytest.mark.parametrize("device", get_available_devices())
11771181
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -1378,14 +1382,22 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
13781382
maxratio = relerr2 / relerr3
13791383

13801384
# Expected (mean, std) for err1, relerr1, maxerr1 per dtype/dim group.
1381-
# Measured from 100 iterations × all storage_type/kind/DQ combos on RTX 4090.
1385+
# Measured from 100 iterations x all storage_type/kind/DQ combos on RTX 4090.
13821386
# std is for individual iterations (not the average), so thresholds are generous
13831387
# enough to accommodate GPU architecture differences (e.g., T4, XPU, Blackwell).
13841388
N_SIGMA = 7
13851389
gemv_thresholds = {
13861390
torch.float16: {
1387-
"le512": {"err1": (0.000052, 0.0000063), "relerr1": (0.00024, 0.000357), "maxerr1": (0.00042, 0.0000687)},
1388-
"gt512": {"err1": (0.000018, 0.0000028), "relerr1": (0.00010, 0.000197), "maxerr1": (0.00017, 0.0000179)},
1391+
"le512": {
1392+
"err1": (0.000052, 0.0000063),
1393+
"relerr1": (0.00024, 0.000357),
1394+
"maxerr1": (0.00042, 0.0000687),
1395+
},
1396+
"gt512": {
1397+
"err1": (0.000018, 0.0000028),
1398+
"relerr1": (0.00010, 0.000197),
1399+
"maxerr1": (0.00017, 0.0000179),
1400+
},
13891401
},
13901402
torch.float32: {
13911403
"le512": {"err1": (2e-8, 2e-9), "relerr1": (8e-7, 1.2e-6), "maxerr1": (6e-8, 2e-8)},

tests/test_linear4bit.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
276276
reassembled = torch.cat(shards).reshape(qB.shape)
277277

278278
assert reassembled.dtype == qB.dtype
279-
assert torch.equal(
280-
reassembled.view(torch.uint8), qB.view(torch.uint8)
281-
), "Bytes changed after shard roundtrip"
279+
assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"
282280

283281
out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
284282
torch.testing.assert_close(out, ref)

tests/test_parametrize.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,12 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
8888

8989
abs_mean, abs_std = expected_errors[quant_type][blocksize]["abs"]
9090
rel_mean, rel_std = expected_errors[quant_type][blocksize]["rel"]
91-
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
92-
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
91+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
92+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
93+
)
94+
assert relerr < rel_mean + N_SIGMA * rel_std, (
95+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
96+
)
9397

9498

9599
@pytest.mark.parametrize("device", get_available_devices())
@@ -129,8 +133,12 @@ def __init__(self, device, dtype):
129133
abs_mean, abs_std = 0.072802, 0.000072
130134
rel_mean, rel_std = 0.203327, 0.000312
131135

132-
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
133-
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
136+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
137+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
138+
)
139+
assert relerr < rel_mean + N_SIGMA * rel_std, (
140+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
141+
)
134142

135143

136144
@pytest.mark.parametrize("device", get_available_devices())
@@ -398,8 +406,12 @@ def test_parametrization_forward_method():
398406
N_SIGMA = 7
399407
abs_mean, abs_std = 0.072842, 0.001180
400408
rel_mean, rel_std = 0.202648, 0.004729
401-
assert err_mean < abs_mean + N_SIGMA * abs_std, f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
402-
assert relerr < rel_mean + N_SIGMA * rel_std, f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
409+
assert err_mean < abs_mean + N_SIGMA * abs_std, (
410+
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
411+
)
412+
assert relerr < rel_mean + N_SIGMA * rel_std, (
413+
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
414+
)
403415

404416

405417
@pytest.mark.parametrize("device", get_available_devices())

0 commit comments

Comments
 (0)