Skip to content

Commit 9d11e85

Browse files
TimDettmers and claude
committed
Add out parameter to kbit_scalar_gemv_tiled for CUDA graph compat
Adds kbit_scalar_gemv_tiled_ op that writes to a pre-allocated output buffer, eliminating the allocate+copy in the kbit_linear dispatch path. The CUDA kernel already accepted an output pointer — this just wires it through the torch.library op layer. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8eadb50 commit 9d11e85

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

bitsandbytes/_ops.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,3 +783,31 @@ def _(
783783
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
784784
M = A.shape[0]
785785
return torch.empty(M, N, device=A.device, dtype=A.dtype)
786+
787+
788+
# K-bit scalar GEMV tiled with pre-allocated output (CUDA graph compatible)
789+
790+
torch.library.define(
791+
"bitsandbytes::kbit_scalar_gemv_tiled_",
792+
"(Tensor A, Tensor B_packed_tiled, Tensor B_absmax_tiled, Tensor codebook, int K_dim, int N, int k, "
793+
"Tensor(a!) out) -> Tensor(a!)",
794+
)
795+
796+
797+
@register_fake("bitsandbytes::kbit_scalar_gemv_tiled_")
798+
def _(
799+
A: torch.Tensor,
800+
B_packed_tiled: torch.Tensor,
801+
B_absmax_tiled: torch.Tensor,
802+
codebook: torch.Tensor,
803+
K_dim: int,
804+
N: int,
805+
k: int,
806+
out: torch.Tensor,
807+
) -> torch.Tensor:
808+
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
809+
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
810+
torch._check(A.shape[0] <= 4, lambda: f"kbit_scalar_gemv_tiled_ supports M<=4, got {A.shape[0]}")
811+
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
812+
torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
813+
return out

bitsandbytes/backends/cuda/ops.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,3 +1332,39 @@ def _(
13321332
ct.c_int(N),
13331333
)
13341334
return out
1335+
1336+
1337+
@register_kernel("bitsandbytes::kbit_scalar_gemv_tiled_", "cuda")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    out: torch.Tensor,
) -> torch.Tensor:
    """CUDA dispatch for the in-place tiled scalar GEMV.

    Writes the M x N result directly into the caller-provided ``out`` buffer
    (no allocation, so the op is CUDA-graph capturable) and returns it.
    The kernel name is selected from A's dtype, B_absmax_tiled's dtype, and
    the bit width ``k``.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"kbit_scalar_gemv_tiled_ supports float16 and bfloat16, got {A.dtype}",
    )

    M = A.shape[0]
    # New safety checks: the C kernel receives a bare pointer and writes M*N
    # elements unconditionally, so a wrong buffer here is memory corruption.
    torch._check(M <= 4, lambda: f"kbit_scalar_gemv_tiled_ supports M<=4, got {M}")
    torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
    torch._check(
        out.shape == (M, N) and out.is_contiguous(),
        lambda: f"out must be a contiguous ({M}, {N}) tensor, got shape {tuple(out.shape)}",
    )

    dtype_suffix = "fp16" if A.dtype == torch.float16 else "bf16"
    abs_suffix = "_fp16abs" if B_absmax_tiled.dtype == torch.float16 else ""

    with _cuda_device_of(A):
        # e.g. ckbit_scalar_gemv_tiled_fp16_fp16abs_k4
        fn = getattr(lib, f"ckbit_scalar_gemv_tiled_{dtype_suffix}{abs_suffix}_k{k}")
        fn(
            get_ptr(A),
            get_ptr(B_packed_tiled),
            get_ptr(B_absmax_tiled),
            get_ptr(codebook),
            get_ptr(out),
            ct.c_int(M),
            ct.c_int(K_dim),
            ct.c_int(N),
        )
    return out

bitsandbytes/functional.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,11 +1299,9 @@ def kbit_linear(
12991299
if M <= 4:
13001300
# Scalar GEMV: tiled layout, one column per block
13011301
if out is not None:
1302-
# scalar GEMV doesn't have an out variant for tiled yet,
1303-
# so compute into temp and copy
1304-
result = torch.ops.bitsandbytes.kbit_scalar_gemv_tiled(A, B_packed, B_absmax, codebook, K_dim, N, k)
1305-
out[:M, :N].copy_(result)
1306-
return out[:M]
1302+
return torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
1303+
A, B_packed, B_absmax, codebook, K_dim, N, k, out[:M]
1304+
)
13071305
return torch.ops.bitsandbytes.kbit_scalar_gemv_tiled(A, B_packed, B_absmax, codebook, K_dim, N, k)
13081306

13091307
if M <= 16:

0 commit comments

Comments
 (0)