Skip to content

Commit e63e29c

Browse files
Fix 4-bit quantization for weight matrices not divisible by blocksize (#1884)
* fix when A.numel() not divisible by blocksize * fix * another one
1 parent a5a7f5d commit e63e29c

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

bitsandbytes/backends/default/ops.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ def _(
248248

249249
# Quantize with the lookup table
250250
code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
251+
# Pad to even length so packing pairs all elements
252+
if scaled.numel() % 2 != 0:
253+
scaled = torch.nn.functional.pad(scaled, (0, 1), value=0.0)
251254
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)
252255

253256
# Pack two quantized values per byte
@@ -274,17 +277,20 @@ def _dequantize_4bit_impl(
274277
A = A.reshape(-1)
275278
# Map nf4 to [-1, 1]
276279
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
277-
n = out_dq.numel()
278280
out_dq[1::2] = A & 0xF
279281
out_dq[::2] = A >> 4
280282
# code is fp32, cast to dtype to avoid the mismatch issue
281283
code = CODE[quant_type].to(dtype).to(A.device)
282284
out_dq = code[out_dq]
283285

286+
# Use the actual output size, not the unpacked size (which may include padding)
287+
n = 1
288+
for s in shape:
289+
n *= s
290+
# Trim any extra elements from padding during quantization
291+
out_dq = out_dq[:n]
292+
284293
# Apply scales
285-
if out_dq.numel() != n:
286-
assert out_dq.numel() == n + 1
287-
out_dq = torch.narrow(out_dq, 0, 0, n)
288294
blocks = n // blocksize
289295
blocks += 1 if n % blocksize > 0 else 0
290296
rem = n % blocksize

bitsandbytes/backends/triton/ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,12 @@ def quantize_4bit(
7676

7777
n = A.numel()
7878

79-
# TODO: Support when weight matrix is not divisible by blocksize
80-
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
79+
# Pad to next multiple of blocksize so the kernel always processes full blocks
80+
remainder = n % blocksize
81+
if remainder != 0:
82+
padding = blocksize - remainder
83+
A = torch.nn.functional.pad(A.view(-1), (0, padding), value=0.0)
84+
n = A.numel()
8185

8286
blocks = -(n // -(blocksize * 2))
8387

tests/test_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,32 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
172172

173173
opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))
174174

175+
@pytest.mark.parametrize("device", get_available_devices())
176+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
177+
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
178+
@pytest.mark.parametrize("blocksize", [64, 128, 256])
179+
def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_type, blocksize):
180+
"""Test quantize/dequantize roundtrip when n_elements is not divisible by blocksize."""
181+
# Shape chosen so numel is NOT divisible by blocksize
182+
shape = (7, blocksize - 1)
183+
A = torch.randn(shape, dtype=dtype, device=device)
184+
storage_dtype = torch.uint8
185+
186+
# Should not raise
187+
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
188+
189+
assert packed.device == A.device
190+
assert absmax.device == A.device
191+
192+
# Dequantize back and verify shape is preserved
193+
out = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, blocksize, quant_type, shape, dtype)
194+
195+
assert out.shape == shape
196+
assert out.dtype == dtype
197+
198+
# Verify output is finite (no NaN/Inf)
199+
assert torch.isfinite(out).all(), "Dequantized output contains NaN or Inf"
200+
175201
@pytest.mark.parametrize("device", get_available_devices())
176202
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
177203
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))

0 commit comments

Comments (0)