Skip to content

Commit 24cf7d1

Browse files
TimDettmers and claude committed
feat: Update vq_linear dispatch to use MMA kernel for M=5-16
- Route M<=4 to scalar GEMV, M=5-16 to vq_gemm_prod, M>16 to dequant+cuBLAS (matching kbit_linear dispatch pattern)
- Update vq_linear_workspace to include C_workspace and tile_counters
- Un-skip MMA test stubs, replace with actual vq_gemm_prod tests
- All 100 VQ tests pass (50 scalar GEMV + 50 dispatch/MMA)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 79ac4dc commit 24cf7d1

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

bitsandbytes/functional.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,8 @@ def vq_linear(
15601560
15611561
Routes to the optimal kernel based on M (batch dimension):
15621562
- M <= 4: scalar GEMV (tiled layout, shmem codebook lookup)
1563-
- M > 4: dequantize to fp16/bf16 + cuBLAS matmul
1563+
- M <= 16: fused dequant + MMA (tiled layout, tensor core)
1564+
- M > 16: dequantize to fp16/bf16 + cuBLAS matmul
15641565
15651566
All paths read tiled B layout (from repack_vq output).
15661567
@@ -1574,6 +1575,8 @@ def vq_linear(
15741575
N: Output dimension of weight matrix.
15751576
out: Optional pre-allocated output [M, N] for CUDA graph compat.
15761577
workspace: Optional dict with pre-allocated buffers:
1578+
'C_workspace': float32 [M, N] for MMA accumulation
1579+
'tile_counters': int32 [m_tiles * n_tiles] for persistent kernel
15771580
'dequant_buf': fp16/bf16 [N * K_dim] for dequant+matmul path
15781581
15791582
Returns:
@@ -1590,7 +1593,18 @@ def vq_linear(
15901593
)
15911594
return torch.ops.bitsandbytes.vq_scalar_gemv_tiled(A, B_packed, B_absmax, codebook, K_dim, N, p)
15921595

1593-
# M > 4: dequantize tiled VQ to dense + cuBLAS matmul
1596+
if M <= 16:
1597+
# Fused dequant + MMA: tiled layout, tensor core path
1598+
k_chunks = 1 # auto-selected internally by the kernel
1599+
if out is not None and workspace is not None:
1600+
C_workspace = workspace["C_workspace"]
1601+
tile_counters = workspace["tile_counters"]
1602+
return torch.ops.bitsandbytes.vq_gemm_prod_(
1603+
A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks, out, C_workspace, tile_counters
1604+
)
1605+
return torch.ops.bitsandbytes.vq_gemm_prod(A, B_packed, B_absmax, codebook, K_dim, N, p, k_chunks)
1606+
1607+
# M > 16: dequantize tiled VQ to dense + cuBLAS matmul
15941608
if workspace is not None and "dequant_buf" in workspace:
15951609
dequant_buf = workspace["dequant_buf"]
15961610
torch.ops.bitsandbytes.dequantize_vq_tiled_(
@@ -1619,12 +1633,17 @@ def vq_linear_workspace(M: int, K_dim: int, N: int, p: int, dtype: torch.dtype,
16191633
device: CUDA device.
16201634
16211635
Returns:
1622-
Dict with 'dequant_buf' tensor.
1636+
Dict with 'C_workspace', 'tile_counters', 'dequant_buf' tensors.
16231637
"""
1638+
TILE_M, TILE_N = 16, 64 # worst-case tile sizes for counter allocation
1639+
m_tiles = (M + TILE_M - 1) // TILE_M
1640+
n_tiles = N // TILE_N
16241641
n_total = N * K_dim
16251642
num_blocks = -(n_total // -32)
16261643

16271644
return {
1645+
"C_workspace": torch.zeros(M, N, device=device, dtype=torch.float32),
1646+
"tile_counters": torch.zeros(m_tiles * n_tiles, device=device, dtype=torch.int32),
16281647
"dequant_buf": torch.empty(num_blocks * 32, device=device, dtype=dtype),
16291648
}
16301649

tests/test_kbit_gemm.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,7 +1203,35 @@ def test_vq_linear_preallocated_output(self, p):
12031203

12041204
@pytest.mark.parametrize("p", [2, 4])
12051205
@pytest.mark.parametrize("M", [5, 8, 16, 32])
1206-
@pytest.mark.skip(reason="Task 5 (VQ MMA kernel) not yet implemented")
12071206
def test_vq_mma_kernel(self, p, M):
1208-
"""VQ MMA kernel correctness (placeholder for Task 5)."""
1209-
pass
1207+
"""VQ MMA kernel (vq_gemm_prod) correctness."""
1208+
from bitsandbytes.functional import create_vq_codebook, quantize_vq, repack_vq
1209+
1210+
K_dim, N = 512, 256
1211+
torch.manual_seed(42)
1212+
1213+
W = torch.randn(N, K_dim)
1214+
codebook = create_vq_codebook(p, device="cuda")
1215+
W_gpu = W.half().cuda()
1216+
packed_flat, absmax_flat, _ = quantize_vq(W_gpu, p=p, codebook=codebook)
1217+
packed_tiled, absmax_tiled = repack_vq(packed_flat, absmax_flat, K_dim, N, p=p)
1218+
1219+
A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
1220+
1221+
C = torch.ops.bitsandbytes.vq_gemm_prod(
1222+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, p, 1,
1223+
)
1224+
1225+
# Reference
1226+
from bitsandbytes.functional import dequantize_vq
1227+
1228+
W_deq = dequantize_vq(packed_flat, absmax_flat, codebook, p=p, n=N * K_dim)
1229+
W_deq = W_deq.reshape(N, K_dim)
1230+
C_ref = (A.float() @ W_deq.float().T).to(A.dtype)
1231+
1232+
diff = (C.float() - C_ref.float()).abs()
1233+
scale = C_ref.float().abs().clamp(min=1.0)
1234+
rel_err = (diff / scale).max().item()
1235+
assert rel_err < 0.10, (
1236+
f"p={p}, M={M}: vq_gemm_prod mismatch. Max rel err: {rel_err:.6f}"
1237+
)

0 commit comments

Comments (0)