Skip to content

Commit 5ea3c89

Browse files
TimDettmers and claude
committed
Remove redundant quant_storage parametrization from test_gemv_4bit
quant_storage is a dtype view over byte-identical packed 4-bit data, introduced for FSDP sharding compatibility. The gemv kernel reads raw bytes via a void pointer and never branches on B.dtype, so testing 4 quant_storage values exercises the same code path 4 times.

Fix test_gemv_4bit to use only uint8 (the default), reducing from 1,536 to 384 test cases (~18 min saved).

Add a dedicated test_quant_storage_shard_roundtrip that properly tests what quant_storage is for: verifying bytes survive an FSDP-style flatten/chunk/reassemble with all 4 storage dtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2d22247 commit 5ea3c89

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

tests/test_functional.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,13 +1281,9 @@ def test_bench_4bit_dequant(self, quant_type):
12811281
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
12821282
@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
12831283
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
1284-
@pytest.mark.parametrize(
1285-
"quant_storage",
1286-
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
1287-
ids=describe_dtype,
1288-
)
12891284
@pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim"))
1290-
def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind):
1285+
def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
1286+
quant_storage = torch.uint8
12911287
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype, quant_storage):
12921288
pytest.skip("This configuration is not supported on HPU.")
12931289

tests/test_linear4bit.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,42 @@ def test_params4bit_torch_chunk_split(device, quant_type):
248248
assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
249249

250250

251+
@pytest.mark.parametrize("device", get_available_devices())
252+
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
253+
@pytest.mark.parametrize(
254+
"quant_storage",
255+
[torch.uint8, torch.float16, torch.bfloat16, torch.float32],
256+
ids=describe_dtype,
257+
)
258+
def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
259+
"""Test that quantized weights survive a flatten-chunk-reassemble roundtrip.
260+
261+
Non-uint8 quant_storage exists so that FSDP can shard quantized tensors
262+
without splitting packed 4-bit pairs. This test simulates FSDP's
263+
shard/gather pattern and verifies numerical correctness after reassembly.
264+
"""
265+
M, K = 256, 128
266+
A = torch.randn(1, K, dtype=torch.float16, device=device)
267+
B = torch.randn(M, K, dtype=torch.float16, device=device)
268+
269+
qB, state = bnb.functional.quantize_4bit(B, quant_type=quant_type, quant_storage=quant_storage)
270+
ref = bnb.functional.gemv_4bit(A, qB.t(), state=state)
271+
272+
# Simulate FSDP: flatten, split into shards, reassemble
273+
flat = qB.flatten()
274+
n_shards = 4
275+
shards = flat.chunk(n_shards)
276+
reassembled = torch.cat(shards).reshape(qB.shape)
277+
278+
assert reassembled.dtype == qB.dtype
279+
assert torch.equal(
280+
reassembled.view(torch.uint8), qB.view(torch.uint8)
281+
), "Bytes changed after shard roundtrip"
282+
283+
out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
284+
torch.testing.assert_close(out, ref)
285+
286+
251287
@pytest.mark.parametrize("device", get_available_devices())
252288
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
253289
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])

0 commit comments

Comments
 (0)