Skip to content

Commit d1f3d75

Browse files
TimDettmers and claude committed
Add out parameter to dequantize_kbit for CUDA graph compatibility
Factor dequant into _dequantize_kbit_impl that accepts a pre-allocated output tensor. Add dequantize_kbit_ in-place op variant following the existing pattern (dequantize_4bit.out, gemv_4bit.out). The public API dequantize_kbit() now accepts an optional out parameter — if provided, the kernel writes into it directly instead of allocating, which is required for CUDA graph replay. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f95a7f2 commit d1f3d75

File tree

5 files changed

+189
-8
lines changed

5 files changed

+189
-8
lines changed

bitsandbytes/_ops.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,3 +475,30 @@ def _(
475475
)
476476
num_blocks = -(n // -32)
477477
return torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)
478+
479+
480+
torch.library.define(
481+
"bitsandbytes::dequantize_kbit_",
482+
"(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)",
483+
)
484+
485+
486+
@register_fake("bitsandbytes::dequantize_kbit_")
487+
def _(
488+
packed: torch.Tensor,
489+
codebook: torch.Tensor,
490+
absmax: torch.Tensor,
491+
k: int,
492+
n: int,
493+
dtype: torch.dtype,
494+
out: torch.Tensor,
495+
) -> torch.Tensor:
496+
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
497+
torch._check(
498+
absmax.dtype in (torch.float32, torch.uint8),
499+
lambda: f"absmax must be float32 or uint8 (E4M4), got {absmax.dtype}",
500+
)
501+
num_blocks = -(n // -32)
502+
torch._check(out.numel() >= num_blocks * 32, lambda: f"out must have at least {num_blocks * 32} elements")
503+
torch._check(out.dtype == dtype, lambda: f"out dtype {out.dtype} must match requested dtype {dtype}")
504+
return out

bitsandbytes/backends/cuda/ops.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -810,15 +810,15 @@ def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, to
810810
}
811811

812812

813-
@register_kernel("bitsandbytes::dequantize_kbit", "cuda")
814-
def _(
813+
def _dequantize_kbit_impl(
815814
packed: torch.Tensor,
816815
codebook: torch.Tensor,
817816
absmax: torch.Tensor,
818817
k: int,
819818
n: int,
820819
dtype: torch.dtype,
821-
) -> torch.Tensor:
820+
out: torch.Tensor,
821+
) -> None:
822822
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
823823
torch._check(
824824
dtype in _KBIT_DTYPE_SUFFIX,
@@ -836,9 +836,6 @@ def _(
836836

837837
absmax = encode_absmax_e4m4(absmax)
838838

839-
num_blocks = -(n // -32)
840-
out = torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)
841-
842839
tname = _KBIT_DTYPE_SUFFIX[dtype]
843840
aname = _KBIT_ABSMAX_SUFFIX[absmax.dtype]
844841

@@ -853,4 +850,31 @@ def _(
853850
_get_tensor_stream(packed),
854851
)
855852

853+
854+
@register_kernel("bitsandbytes::dequantize_kbit", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Out-of-place CUDA k-bit dequantize.

    Allocates a fresh output buffer padded to whole 32-element blocks,
    then delegates the actual kernel launch to ``_dequantize_kbit_impl``.
    """
    # ceil(n / 32) blocks of 32 elements each.
    padded_len = 32 * ((n + 31) // 32)
    out = torch.empty(padded_len, device=packed.device, dtype=dtype)
    _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out)
    return out
867+
868+
869+
@register_kernel("bitsandbytes::dequantize_kbit_", "cuda")
def _(
    packed: torch.Tensor,
    codebook: torch.Tensor,
    absmax: torch.Tensor,
    k: int,
    n: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    """In-place CUDA k-bit dequantize.

    Writes directly into the caller-provided ``out`` buffer instead of
    allocating — this keeps the op capturable by CUDA graphs — and
    returns that same buffer.
    """
    _dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out)
    return out

bitsandbytes/functional.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,7 @@ def dequantize_kbit(
11791179
k: int,
11801180
n: int,
11811181
dtype: torch.dtype = torch.float16,
1182+
out: Optional[Tensor] = None,
11821183
) -> Tensor:
11831184
"""Dequantize a k-bit blockwise quantized tensor.
11841185
@@ -1190,12 +1191,25 @@ def dequantize_kbit(
11901191
k: Bit width (2, 3, 4, or 5).
11911192
n: Number of original elements.
11921193
dtype: Output dtype. Defaults to float16.
1194+
out: Optional pre-allocated output tensor for CUDA graph compatibility.
1195+
Must have at least ceil(n/32)*32 elements and matching dtype.
11931196
11941197
Returns:
11951198
Dequantized tensor of shape (n,) with the given dtype.
11961199
"""
1197-
out = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, dtype)
1198-
return out[:n]
1200+
num_blocks = -(n // -32)
1201+
padded_n = num_blocks * 32
1202+
1203+
if out is not None:
1204+
if out.numel() < padded_n:
1205+
raise ValueError(f"out tensor has {out.numel()} elements, need at least {padded_n}")
1206+
if out.dtype != dtype:
1207+
raise ValueError(f"out dtype {out.dtype} does not match requested dtype {dtype}")
1208+
torch.ops.bitsandbytes.dequantize_kbit_(packed, codebook, absmax, k, n, dtype, out)
1209+
return out[:n]
1210+
1211+
result = torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, dtype)
1212+
return result[:n]
11991213

12001214

12011215
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)

spec.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Spec: Add `out` parameter to kbit dequantize for CUDA graph compatibility
2+
3+
## Problem
4+
5+
`dequantize_kbit` allocates a fresh output tensor on every call. This breaks
6+
CUDA graph capture, which requires kernels to write to the same memory address
7+
on every replay. The dequant is on the inference hot path and needs graph support.
8+
9+
## Changes
10+
11+
### 1. CUDA backend (`bitsandbytes/backends/cuda/ops.py`)
12+
13+
Factor the kernel call into `_dequantize_kbit_impl(packed, codebook, absmax, k, n, dtype, out)`:
14+
- Accepts a pre-allocated `out` tensor
15+
- Validates `out` shape, dtype, device
16+
- Calls the C kernel writing into `out`
17+
18+
The existing `dequantize_kbit` registered kernel allocates `out` then calls `_impl`.
19+
20+
### 2. torch op definition (`bitsandbytes/_ops.py`)
21+
22+
Add a second op `bitsandbytes::dequantize_kbit_` (in-place variant with trailing
23+
underscore, matching existing pattern for `dequantize_4bit`):
24+
- Signature: `(Tensor packed, Tensor codebook, Tensor absmax, int k, int n, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)`
25+
- Fake implementation validates shapes, returns `out`
26+
27+
### 3. Public API (`bitsandbytes/functional.py`)
28+
29+
Add optional `out` parameter to `dequantize_kbit()`:
30+
- `out: Optional[Tensor] = None`
31+
- If provided, validate shape/dtype/device, pass to impl
32+
- If None, allocate as before
33+
34+
### 4. Tests
35+
36+
Add test cases in `tests/test_kbit_quantization.py`:
37+
- Dequant with pre-allocated `out` tensor matches normal dequant
38+
- `out` tensor with wrong shape raises error
39+
- `out` tensor with wrong dtype raises error
40+
41+
## Files touched
42+
43+
- `bitsandbytes/backends/cuda/ops.py`
44+
- `bitsandbytes/_ops.py`
45+
- `bitsandbytes/functional.py`
46+
- `tests/test_kbit_quantization.py`
47+
48+
## Not in scope
49+
50+
- `quantize_kbit` out parameter (runs once at model load, not on hot path)

tests/test_kbit_quantization.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,3 +1398,69 @@ def test_storage_reduction(self):
13981398
# uint8 should use 4x less storage (ignoring padding)
13991399
assert absmax_e4.element_size() == 1
14001400
assert absmax_f32.element_size() == 4
1401+
1402+
1403+
class TestDequantizeKbitOut:
    """Tests for dequantize_kbit with a pre-allocated out tensor (CUDA graph compatibility).

    NOTE(review): quantize_kbit's return value is unpacked here as
    (packed, absmax, cb) but passed positionally to dequantize_kbit,
    whose underlying op schema reads (packed, codebook, absmax, ...).
    Either the unpacking names or the argument order looks swapped —
    confirm against quantize_kbit's actual return order.
    """

    @pytest.mark.parametrize("k", [2, 3, 4, 5])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_out_matches_normal(self, k, dtype):
        """Dequant into a pre-allocated buffer must match the allocating path."""
        from bitsandbytes.functional import dequantize_kbit, quantize_kbit

        n = 1024
        src = torch.randn(n, dtype=dtype, device="cuda")
        packed, absmax, cb = quantize_kbit(src, k=k, absmax_format="e4m4")

        reference = dequantize_kbit(packed, absmax, cb, k=k, n=n, dtype=dtype)

        padded_len = 32 * ((n + 31) // 32)
        buf = torch.empty(padded_len, device="cuda", dtype=dtype)
        got = dequantize_kbit(packed, absmax, cb, k=k, n=n, dtype=dtype, out=buf)

        assert got.shape == reference.shape
        assert torch.equal(got, reference)
        # The result must be a view of the caller's buffer, not a fresh allocation.
        assert got.data_ptr() == buf.data_ptr()

    def test_out_reuse_same_buffer(self):
        """Two calls through the same out buffer must agree byte-for-byte."""
        from bitsandbytes.functional import dequantize_kbit, quantize_kbit

        n = 512
        src = torch.randn(n, dtype=torch.float16, device="cuda")
        packed, absmax, cb = quantize_kbit(src, k=4, absmax_format="e4m4")

        padded_len = 32 * ((n + 31) // 32)
        buf = torch.empty(padded_len, device="cuda", dtype=torch.float16)

        first_run = dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=buf)
        second_run = dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=buf)

        assert torch.equal(first_run, second_run)
        assert first_run.data_ptr() == second_run.data_ptr()

    def test_out_wrong_dtype_raises(self):
        """A dtype mismatch between out and the requested dtype raises ValueError."""
        from bitsandbytes.functional import dequantize_kbit, quantize_kbit

        n = 256
        src = torch.randn(n, dtype=torch.float16, device="cuda")
        packed, absmax, cb = quantize_kbit(src, k=4, absmax_format="e4m4")

        buf = torch.empty(256, device="cuda", dtype=torch.float32)
        with pytest.raises(ValueError, match="does not match"):
            dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=buf)

    def test_out_too_small_raises(self):
        """An undersized out buffer raises ValueError."""
        from bitsandbytes.functional import dequantize_kbit, quantize_kbit

        n = 256
        src = torch.randn(n, dtype=torch.float16, device="cuda")
        packed, absmax, cb = quantize_kbit(src, k=4, absmax_format="e4m4")

        buf = torch.empty(128, device="cuda", dtype=torch.float16)
        with pytest.raises(ValueError, match="need at least"):
            dequantize_kbit(packed, absmax, cb, k=4, n=n, dtype=torch.float16, out=buf)

0 commit comments

Comments
 (0)