Skip to content

Commit 79ac4dc

Browse files
TimDettmers and claude committed
feat: Add VQ MMA GEMM kernel (vq_gemm_prod) for M>4
Persistent kernel with cp.async pipeline, split-K, and codebook in shared memory. Modeled on kbit_gemm_prod with byte-indexed shmem codebook lookup replacing bit-plane extraction + warp shuffle. - p=2: half2[256] shmem codebook (1KB), 2 words per 16-element segment - p=4: split half2[256]*2 shmem codebook (2KB), 1 word per segment - Full registration chain: ops.cu, pythonInterface.cpp, _ops.py, ops.py - Torch ops: vq_gemm_prod (allocating) and vq_gemm_prod_ (CUDA graph) - Correctness verified: max rel_err < 0.001 for all (p,M) combos Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 4aecedf commit 79ac4dc

File tree

4 files changed

+676
-0
lines changed

4 files changed

+676
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,64 @@ def _(
943943
return out
944944

945945

946+
# VQ fused dequant + MMA GEMM: codebook-based quantized matmul via tensor cores
947+
948+
torch.library.define(
949+
"bitsandbytes::vq_gemm_prod",
950+
"(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int p, int k_chunks) -> Tensor",
951+
)
952+
953+
954+
@register_fake("bitsandbytes::vq_gemm_prod")
955+
def _(
956+
A: torch.Tensor,
957+
B_packed: torch.Tensor,
958+
B_absmax: torch.Tensor,
959+
codebook: torch.Tensor,
960+
K_dim: int,
961+
N: int,
962+
p: int,
963+
k_chunks: int,
964+
) -> torch.Tensor:
965+
torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
966+
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
967+
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
968+
M = A.shape[0]
969+
return torch.empty(M, N, device=A.device, dtype=A.dtype)
970+
971+
972+
# VQ fused dequant + MMA GEMM with pre-allocated output and workspace (CUDA graph compatible)
973+
974+
torch.library.define(
975+
"bitsandbytes::vq_gemm_prod_",
976+
"(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int p, int k_chunks, "
977+
"Tensor(a!) out, Tensor C_workspace, Tensor tile_counters) -> Tensor(a!)",
978+
)
979+
980+
981+
@register_fake("bitsandbytes::vq_gemm_prod_")
982+
def _(
983+
A: torch.Tensor,
984+
B_packed: torch.Tensor,
985+
B_absmax: torch.Tensor,
986+
codebook: torch.Tensor,
987+
K_dim: int,
988+
N: int,
989+
p: int,
990+
k_chunks: int,
991+
out: torch.Tensor,
992+
C_workspace: torch.Tensor,
993+
tile_counters: torch.Tensor,
994+
) -> torch.Tensor:
995+
torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
996+
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
997+
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
998+
M = A.shape[0]
999+
torch._check(out.shape == (M, N), lambda: f"out must be [{M}, {N}], got {list(out.shape)}")
1000+
torch._check(out.dtype == A.dtype, lambda: f"out dtype {out.dtype} must match A dtype {A.dtype}")
1001+
return out
1002+
1003+
9461004
# K-bit grouped expert GEMM: batch multiple MoE expert GEMMs into one launch
9471005

9481006
torch.library.define(

bitsandbytes/backends/cuda/ops.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,86 @@ def _(
14701470
return out
14711471

14721472

1473+
def _vq_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks, C, C_workspace, tile_counters):
1474+
dtype_suffix = "fp16" if A.dtype == torch.float16 else "bf16"
1475+
1476+
# Zero workspace and counters (required by atomicAdd accumulation)
1477+
C_workspace.zero_()
1478+
tile_counters.zero_()
1479+
1480+
with _cuda_device_of(A):
1481+
fn = getattr(lib, f"cvq_gemm_prod_{dtype_suffix}_p{p}")
1482+
fn(
1483+
get_ptr(A),
1484+
get_ptr(B_packed),
1485+
get_ptr(B_absmax),
1486+
get_ptr(codebook),
1487+
get_ptr(C),
1488+
get_ptr(C_workspace),
1489+
get_ptr(tile_counters),
1490+
ct.c_int(A.shape[0]),
1491+
ct.c_int(K_dim),
1492+
ct.c_int(N),
1493+
ct.c_int(k_chunks),
1494+
_get_tensor_stream(A),
1495+
)
1496+
1497+
1498+
@register_kernel("bitsandbytes::vq_gemm_prod", "cuda")
1499+
def _(
1500+
A: torch.Tensor,
1501+
B_packed: torch.Tensor,
1502+
B_absmax: torch.Tensor,
1503+
codebook: torch.Tensor,
1504+
K_dim: int,
1505+
N: int,
1506+
p: int,
1507+
k_chunks: int,
1508+
) -> torch.Tensor:
1509+
torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
1510+
torch._check(
1511+
A.dtype in (torch.float16, torch.bfloat16),
1512+
lambda: f"vq_gemm_prod supports float16 and bfloat16, got {A.dtype}",
1513+
)
1514+
1515+
M = A.shape[0]
1516+
C = torch.empty(M, N, device=A.device, dtype=A.dtype)
1517+
1518+
TILE_M = 16
1519+
TILE_N = 64 # worst case (most tiles)
1520+
m_tiles = (M + TILE_M - 1) // TILE_M
1521+
n_tiles = N // TILE_N
1522+
1523+
C_workspace = torch.zeros(M, N, device=A.device, dtype=torch.float32)
1524+
tile_counters = torch.zeros(m_tiles * n_tiles, device=A.device, dtype=torch.int32)
1525+
1526+
_vq_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks, C, C_workspace, tile_counters)
1527+
return C
1528+
1529+
1530+
@register_kernel("bitsandbytes::vq_gemm_prod_", "cuda")
1531+
def _(
1532+
A: torch.Tensor,
1533+
B_packed: torch.Tensor,
1534+
B_absmax: torch.Tensor,
1535+
codebook: torch.Tensor,
1536+
K_dim: int,
1537+
N: int,
1538+
p: int,
1539+
k_chunks: int,
1540+
out: torch.Tensor,
1541+
C_workspace: torch.Tensor,
1542+
tile_counters: torch.Tensor,
1543+
) -> torch.Tensor:
1544+
torch._check(p in (2, 4), lambda: f"p must be 2 or 4, got {p}")
1545+
torch._check(
1546+
A.dtype in (torch.float16, torch.bfloat16),
1547+
lambda: f"vq_gemm_prod_ supports float16 and bfloat16, got {A.dtype}",
1548+
)
1549+
_vq_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks, out, C_workspace, tile_counters)
1550+
return out
1551+
1552+
14731553
def _kbit_grouped_gemm_check(A_concat, B_packed_all, B_absmax_all, codebook, expert_offsets, N, k):
14741554
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
14751555
torch._check(

0 commit comments

Comments
 (0)