Skip to content

Commit 01f7b5f

Browse files
TimDettmers and claude committed
feat: Add VQ scalar GEMV kernel with flat layout support
- vq_scalar_gemv<P_VAL, M_VAL, TILED, scalar_t, ABSMAX_T> kernel in ops.cu - Shared memory codebook: p=2 → half2[256] (1KB), p=4 → split half2[256]+half2[256] (2KB) - Inner loop: per-byte iteration, codebook lookup, fp16→fp32 convert, FMA - Supports M=1-4, fp16/bf16, E4M4 and float32 absmax - 64 threads (2 warps), 40 registers M=1, 0 spills, full occupancy - Correctness verified: all (p,dtype,M,K,N) combos pass rtol=1e-2 - Performance: 7.4 us for 2048x5120 M=1 p=2 (vs 2.1 us target - needs optimization) - Full registration chain: ops.cu → pythonInterface.cpp → _ops.py → backends/cuda/ops.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a47ad9f commit 01f7b5f

File tree

4 files changed

+631
-0
lines changed

4 files changed

+631
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,106 @@ def _(
617617
return out
618618

619619

620+
# VQ scalar GEMV: byte-indexed codebook lookup GEMV for M=1-4

torch.library.define(
    "bitsandbytes::vq_scalar_gemv",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int p) -> Tensor",
)


@register_fake("bitsandbytes::vq_scalar_gemv")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
) -> torch.Tensor:
    """Fake (meta) kernel: validate arguments and allocate the [M, N] output.

    Runs under torch.compile / meta tracing; performs no computation.
    """
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(A.shape[0] <= 4, lambda: f"vq_scalar_gemv supports M<=4, got {A.shape[0]}")
    torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
    # One output row per input row; dtype/device follow A.
    return torch.empty(A.shape[0], N, device=A.device, dtype=A.dtype)
644+
645+
646+
torch.library.define(
    "bitsandbytes::vq_scalar_gemv.out",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int p, Tensor(a!) out) -> ()",
)


@register_fake("bitsandbytes::vq_scalar_gemv.out")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
    out: torch.Tensor,
) -> None:
    """Fake (meta) kernel for the out-variant: the op only mutates ``out``, so
    there is nothing to allocate or return during tracing."""
664+
665+
666+
# VQ scalar GEMV with tiled B layout

torch.library.define(
    "bitsandbytes::vq_scalar_gemv_tiled",
    "(Tensor A, Tensor B_packed_tiled, Tensor B_absmax_tiled, Tensor codebook, int K_dim, int N, int p) -> Tensor",
)


@register_fake("bitsandbytes::vq_scalar_gemv_tiled")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
) -> torch.Tensor:
    """Fake (meta) kernel for the tiled-layout variant: validate arguments and
    allocate the [M, N] output. No computation happens here."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(A.shape[0] <= 4, lambda: f"vq_scalar_gemv_tiled supports M<=4, got {A.shape[0]}")
    torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
    # One output row per input row; dtype/device follow A.
    return torch.empty(A.shape[0], N, device=A.device, dtype=A.dtype)
690+
691+
692+
# VQ scalar GEMV tiled with pre-allocated output (CUDA graph compatible)

torch.library.define(
    "bitsandbytes::vq_scalar_gemv_tiled_",
    "(Tensor A, Tensor B_packed_tiled, Tensor B_absmax_tiled, Tensor codebook, int K_dim, int N, int p, "
    "Tensor(a!) out) -> Tensor(a!)",
)


@register_fake("bitsandbytes::vq_scalar_gemv_tiled_")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) kernel for the in-place tiled variant: validate arguments and
    return the caller-provided ``out`` (aliased, per the Tensor(a!) annotation)."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(A.shape[0] <= 4, lambda: f"vq_scalar_gemv_tiled_ supports M<=4, got {A.shape[0]}")
    torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
    torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
    return out
718+
719+
620720
# K-bit repack: flat bit-plane layout -> GEMM-tiled layout
621721

622722
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,6 +1059,113 @@ def _(
10591059
return out
10601060

10611061

1062+
def _vq_scalar_gemv_impl(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
    out: torch.Tensor,
    tiled: bool = False,
) -> None:
    """Dispatch to the native VQ scalar GEMV kernel, writing the result into ``out``.

    Selects the exported C symbol by dtype (fp16/bf16), layout (flat vs tiled)
    and packing width ``p``, then invokes it on A's device and CUDA stream.
    """
    # Callers have already validated the dtype, so anything non-fp16 is bf16 here.
    if A.dtype == torch.float16:
        suffix = "fp16"
    else:
        suffix = "bf16"
    variant = "_tiled" if tiled else ""
    kernel = getattr(lib, "cvq_scalar_gemv" + variant + "_" + suffix + f"_p{p}")

    with _cuda_device_of(A):
        kernel(
            get_ptr(A),
            get_ptr(B_packed),
            get_ptr(B_absmax),
            get_ptr(codebook),
            get_ptr(out),
            ct.c_int(A.shape[0]),
            ct.c_int(K_dim),
            ct.c_int(N),
            _get_tensor_stream(A),
        )
1090+
1091+
1092+
@register_kernel("bitsandbytes::vq_scalar_gemv", "cuda")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
) -> torch.Tensor:
    """CUDA kernel registration: allocate the output then run the native GEMV."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"vq_scalar_gemv supports float16 and bfloat16, got {A.dtype}",
    )
    # Fresh [M, N] output in A's dtype on A's device; filled by the kernel.
    result = torch.empty(A.shape[0], N, device=A.device, dtype=A.dtype)
    _vq_scalar_gemv_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, out=result)
    return result
1111+
1112+
1113+
@register_kernel("bitsandbytes::vq_scalar_gemv.out", "cuda")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
    out: torch.Tensor,
) -> None:
    """CUDA kernel registration for the out-variant: run the native GEMV into
    the caller-provided ``out`` tensor (CUDA-graph friendly, no allocation).

    Previously this was the only registered kernel that skipped validation and
    dispatched directly; an unsupported dtype would have been silently treated
    as bf16 by the impl's suffix selection. Validate here, consistently with
    the sibling kernels.
    """
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"vq_scalar_gemv supports float16 and bfloat16, got {A.dtype}",
    )
    torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
    _vq_scalar_gemv_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, out=out)
1125+
1126+
1127+
@register_kernel("bitsandbytes::vq_scalar_gemv_tiled", "cuda")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
) -> torch.Tensor:
    """CUDA kernel registration for the tiled-B layout: allocate the output
    then run the native GEMV with tiled=True."""
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"vq_scalar_gemv_tiled supports float16 and bfloat16, got {A.dtype}",
    )
    # Fresh [M, N] output in A's dtype on A's device; filled by the kernel.
    result = torch.empty(A.shape[0], N, device=A.device, dtype=A.dtype)
    _vq_scalar_gemv_impl(A, B_packed_tiled, B_absmax_tiled, codebook, K_dim, N, p, out=result, tiled=True)
    return result
1146+
1147+
1148+
@register_kernel("bitsandbytes::vq_scalar_gemv_tiled_", "cuda")
def _(
    A: torch.Tensor,
    B_packed_tiled: torch.Tensor,
    B_absmax_tiled: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    p: int,
    out: torch.Tensor,
) -> torch.Tensor:
    """CUDA kernel registration for the in-place tiled variant: run the native
    GEMV into the caller-provided ``out`` and return it (CUDA-graph compatible).

    Fixes: removed an unused local (``M = A.shape[0]``) and added the
    ``out.dtype == A.dtype`` check that this op's fake implementation already
    enforces, so eager and traced behavior agree.
    """
    torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"vq_scalar_gemv_tiled_ supports float16 and bfloat16, got {A.dtype}",
    )
    torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
    _vq_scalar_gemv_impl(A, B_packed_tiled, B_absmax_tiled, codebook, K_dim, N, p, out=out, tiled=True)
    return out
1167+
1168+
10621169
@register_kernel("bitsandbytes::repack_kbit", "cuda")
10631170
def _(
10641171
packed_flat: torch.Tensor,

0 commit comments

Comments
 (0)