Skip to content

Commit 505a00a

Browse files
TimDettmers and claude authored
Handle non-contiguous tensors in quantize/dequantize ops (#1859)
* Handle non-contiguous tensors in quantize_4bit and quantize_blockwise (#1342, #1690) Add A.contiguous() calls at the top of quantize_blockwise, quantize_4bit, and their dequantize counterparts in the CUDA backend. The CUDA kernels use raw pointers and assume contiguous memory layout, so non-contiguous inputs (e.g. tensor slices with strides) produced silently incorrect results. Add regression tests verifying non-contiguous tensors produce identical results to their contiguous equivalents for all four ops. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: Fix ruff format violation in test_linear4bit.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c2ae381 commit 505a00a

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def _get_col_absmax(
209209

210210
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212+
A = A.contiguous()
212213
torch._check_is_size(blocksize)
213214

214215
if ROCM_WARP_SIZE_64:
@@ -269,6 +270,7 @@ def _(
269270
def _dequantize_blockwise_impl(
270271
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
271272
) -> None:
273+
A = A.contiguous()
272274
if ROCM_WARP_SIZE_64:
273275
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
274276
else:
@@ -303,6 +305,7 @@ def _dequantize_blockwise_impl(
303305
def _(
304306
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
305307
) -> tuple[torch.Tensor, torch.Tensor]:
308+
A = A.contiguous()
306309
if ROCM_WARP_SIZE_64:
307310
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
308311
else:
@@ -385,6 +388,7 @@ def _dequantize_4bit_impl(
385388
dtype: torch.dtype,
386389
out: torch.Tensor,
387390
) -> None:
391+
A = A.contiguous()
388392
if ROCM_WARP_SIZE_64:
389393
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
390394
else:

tests/test_ops.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,108 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
246246
assert out.isreal().all()
247247

248248
opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))
249+
250+
251+
class TestNonContiguousInputs:
    """Regression tests for #1342 and #1690: quantize/dequantize must give the
    same results for non-contiguous inputs as for their contiguous copies.

    The CUDA kernels operate on raw pointers and assume a contiguous layout,
    so strided views (e.g. tensor slices) previously produced silently
    incorrect results. Each test builds a non-contiguous tensor, runs the op
    on it and on a `.contiguous()` copy, and asserts the outputs match.
    """

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_blockwise_non_contiguous(self, device, dtype, blocksize):
        # The contiguity fix lives in the CUDA backend; skip every other
        # device (not only CPU) so non-CUDA accelerators don't exercise an
        # unpatched path. Matches the guard used by the 4-bit tests below.
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        code = bitsandbytes.functional.create_dynamic_map().to(device)

        # Slicing with a step yields a strided (non-contiguous) view.
        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
        A_noncontig = A_full[:, ::2, :, :]
        assert not A_noncontig.is_contiguous()

        A_contig = A_noncontig.contiguous()

        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_blockwise(A_noncontig, code, blocksize)
        out_c, absmax_c = torch.ops.bitsandbytes.quantize_blockwise(A_contig, code, blocksize)

        torch.testing.assert_close(absmax_nc, absmax_c)
        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_dequantize_blockwise_non_contiguous(self, device, dtype, blocksize):
        # Same CUDA-only guard as the other tests in this class.
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

        # Quantize a contiguous tensor first ...
        A = torch.randn(1024, 1024, dtype=dtype, device=device)
        quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

        # ... then force a non-contiguous uint8 view by slicing a padded
        # buffer. (Note: `quantized.t().t()` restores the original strides and
        # is always contiguous, so a double transpose cannot be used here.)
        q_padded = torch.zeros(1024, 1025, dtype=torch.uint8, device=device)
        q_padded[:, :1024] = quantized
        q_noncontig = q_padded[:, :1024]

        assert not q_noncontig.is_contiguous()
        q_contig = q_noncontig.contiguous()

        out_nc = torch.ops.bitsandbytes.dequantize_blockwise(q_noncontig, absmax, code, blocksize, dtype)
        out_c = torch.ops.bitsandbytes.dequantize_blockwise(q_contig, absmax, code, blocksize, dtype)

        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_4bit_non_contiguous(self, device, dtype, quant_type, blocksize):
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        # Reproduce issue #1342: non-contiguous tensor from slicing.
        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
        A_noncontig = A_full[:, ::2, :, :]
        assert not A_noncontig.is_contiguous()

        A_contig = A_noncontig.contiguous()
        storage_dtype = torch.uint8

        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
        out_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

        torch.testing.assert_close(absmax_nc, absmax_c)
        torch.testing.assert_close(out_nc, out_c)

    @pytest.mark.parametrize("device", get_available_devices())
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
    @pytest.mark.parametrize("blocksize", [64, 128, 256])
    def test_quantize_4bit_roundtrip_non_contiguous(self, device, dtype, quant_type, blocksize):
        """End-to-end test: quantize non-contiguous, dequantize, compare with contiguous path."""
        if device != "cuda":
            pytest.skip("Non-contiguous fix targets CUDA backend only")

        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
        A_noncontig = A_full[:, ::2, :, :]
        assert not A_noncontig.is_contiguous()

        A_contig = A_noncontig.contiguous()
        storage_dtype = torch.uint8

        # Quantize both layouts.
        q_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
        q_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

        # Dequantize both and compare end-to-end results.
        shape = A_contig.shape
        deq_nc = torch.ops.bitsandbytes.dequantize_4bit(q_nc, absmax_nc, blocksize, quant_type, shape, dtype)
        deq_c = torch.ops.bitsandbytes.dequantize_4bit(q_c, absmax_c, blocksize, quant_type, shape, dtype)

        torch.testing.assert_close(deq_nc, deq_c)

0 commit comments

Comments
 (0)