
Commit bff83e6

TimDettmers and claude committed
Add Stage 2 repack kernel, Stage 3 minimal GEMM kernel (76 tests pass)
Stage 2: CUDA repack kernel transforms flat bit-plane packed data into GEMM-tiled layout. Bit-exact match with Python reference for all K values (2,3,4,5) and matrix sizes.

Stage 3: Minimal fused kbit dequant + GEMM kernel using m16n8k16 tensor core MMA instructions with fp32 accumulation. Synchronous shared memory loads, 1 block per output tile, no pipeline. Validates tiled addressing, bit-plane extraction, codebook lookup via __shfl_sync, MMA fragment assembly, and output write.

Key fix: A-fragment register ordering for m16n8k16 must be {row_lo/k_lo, row_hi/k_lo, row_lo/k_hi, row_hi/k_hi}, NOT the naive {row_lo/k_lo, row_lo/k_hi, row_hi/k_lo, row_hi/k_hi}. This follows from the Turing decomposition into two m16n8k8 operations where a[0],a[1] handle k_lo and a[2],a[3] handle k_hi.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
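The register ordering called out above matches the PTX m16n8k16 A-fragment layout. The short Python sketch below is an illustration only (the helper name is made up, not part of this commit); it enumerates, for one lane, which (row, k) elements of the 16x16 A tile land in each of the four fragment registers:

# Illustration only (hypothetical helper, not part of this commit): per-lane
# (row, k) coordinates covered by each A-fragment register of mma m16n8k16,
# following the ordering described above:
#   a[0]=row_lo/k_lo, a[1]=row_hi/k_lo, a[2]=row_lo/k_hi, a[3]=row_hi/k_hi
def a_fragment_coords(lane: int) -> list[list[tuple[int, int]]]:
    group = lane // 4            # 8 groups of 4 lanes
    tig = lane % 4               # thread index within the group
    coords = []
    for reg in range(4):         # a[0]..a[3], each holding two fp16 values
        row = group + (8 if reg in (1, 3) else 0)    # a[1], a[3] cover row_hi
        k0 = 2 * tig + (8 if reg in (2, 3) else 0)   # a[2], a[3] cover k_hi
        coords.append([(row, k0), (row, k0 + 1)])
    return coords

# Example: lane 0 holds (0,0),(0,1) in a[0]; (8,0),(8,1) in a[1];
# (0,8),(0,9) in a[2]; and (8,8),(8,9) in a[3].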
1 parent f95a7f2 commit bff83e6

5 files changed

Lines changed: 1440 additions & 0 deletions


bitsandbytes/_ops.py

Lines changed: 49 additions & 0 deletions
@@ -475,3 +475,52 @@ def _(
    )
    num_blocks = -(n // -32)
    return torch.empty(num_blocks * 32, device=packed.device, dtype=dtype)


# K-bit repack: flat bit-plane layout -> GEMM-tiled layout
torch.library.define(
    "bitsandbytes::repack_kbit",
    "(Tensor packed_flat, Tensor absmax_flat, int K_dim, int N, int k) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::repack_kbit")
def _(packed_flat: torch.Tensor, absmax_flat: torch.Tensor, K_dim: int, N: int, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    TILE_K, TILE_N, BLOCKSIZE = 64, 128, 32
    torch._check(N % TILE_N == 0, lambda: f"N ({N}) must be divisible by {TILE_N}")
    torch._check(K_dim % BLOCKSIZE == 0, lambda: f"K_dim ({K_dim}) must be divisible by {BLOCKSIZE}")
    K_dim_padded = ((K_dim + TILE_K - 1) // TILE_K) * TILE_K
    k_tiles = K_dim_padded // TILE_K
    n_tiles = N // TILE_N
    k_blocks_per_tile = TILE_K // BLOCKSIZE
    total_words = k_tiles * n_tiles * TILE_N * k_blocks_per_tile * k
    total_absmax = k_tiles * n_tiles * TILE_N * k_blocks_per_tile
    packed_tiled = torch.empty(total_words, device=packed_flat.device, dtype=torch.int32)
    absmax_tiled = torch.empty(total_absmax, device=packed_flat.device, dtype=torch.uint8)
    return packed_tiled, absmax_tiled


# K-bit fused dequant + GEMM: C[M,N] = A[M,K_dim] * W_kbit^T
torch.library.define(
    "bitsandbytes::kbit_gemm",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int k) -> Tensor",
)


@register_fake("bitsandbytes::kbit_gemm")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
) -> torch.Tensor:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    M = A.shape[0]
    return torch.empty(M, N, device=A.device, dtype=A.dtype)
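As a quick sanity check on the sizing math in the fake registration above, here is a worked example with illustrative values (not part of the diff); the tiled buffer comes out to exactly k bits per weight plus one absmax per 32-weight block:

# Worked example of the buffer-size formulas above (illustrative values only).
K_dim, N, k = 4096, 4096, 4
TILE_K, TILE_N, BLOCKSIZE = 64, 128, 32

k_tiles = K_dim // TILE_K                 # 64 (no padding needed here)
n_tiles = N // TILE_N                     # 32
k_blocks_per_tile = TILE_K // BLOCKSIZE   # 2

total_words = k_tiles * n_tiles * TILE_N * k_blocks_per_tile * k   # 2_097_152 int32 words
total_absmax = k_tiles * n_tiles * TILE_N * k_blocks_per_tile      # 524_288 absmax entries

assert total_words * 32 == K_dim * N * k        # k bits per weight
assert total_absmax * BLOCKSIZE == K_dim * N    # one absmax per 32-weight block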

bitsandbytes/backends/cuda/ops.py

Lines changed: 77 additions & 0 deletions
@@ -854,3 +854,80 @@ def _(
        )

    return out


@register_kernel("bitsandbytes::repack_kbit", "cuda")
def _(
    packed_flat: torch.Tensor,
    absmax_flat: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(packed_flat.dtype == torch.int32, lambda: f"packed_flat must be int32, got {packed_flat.dtype}")
    torch._check(absmax_flat.dtype == torch.float32, lambda: f"absmax_flat must be float32, got {absmax_flat.dtype}")

    TILE_K, TILE_N, BLOCKSIZE = 64, 128, 32
    torch._check(N % TILE_N == 0, lambda: f"N ({N}) must be divisible by {TILE_N}")
    torch._check(K_dim % BLOCKSIZE == 0, lambda: f"K_dim ({K_dim}) must be divisible by {BLOCKSIZE}")

    K_dim_padded = ((K_dim + TILE_K - 1) // TILE_K) * TILE_K
    k_tiles = K_dim_padded // TILE_K
    n_tiles = N // TILE_N
    k_blocks_per_tile = TILE_K // BLOCKSIZE
    total_words = k_tiles * n_tiles * TILE_N * k_blocks_per_tile * k
    total_absmax = k_tiles * n_tiles * TILE_N * k_blocks_per_tile

    # Zero-fill for padding regions (when K_dim is not multiple of TILE_K)
    packed_tiled = torch.zeros(total_words, device=packed_flat.device, dtype=torch.int32)
    absmax_tiled = torch.zeros(total_absmax, device=packed_flat.device, dtype=torch.uint8)

    with _cuda_device_of(packed_flat):
        fn = getattr(lib, f"crepack_kbit_k{k}")
        fn(
            get_ptr(packed_flat),
            get_ptr(absmax_flat),
            get_ptr(packed_tiled),
            get_ptr(absmax_tiled),
            ct.c_int(K_dim),
            ct.c_int(N),
        )

    return packed_tiled, absmax_tiled


@register_kernel("bitsandbytes::kbit_gemm", "cuda")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
) -> torch.Tensor:
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dtype == torch.float16, lambda: f"kbit_gemm currently supports float16 only, got {A.dtype}")
    torch._check(B_packed.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed.dtype}")
    torch._check(B_absmax.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax.dtype}")
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")

    M = A.shape[0]
    C = torch.empty(M, N, device=A.device, dtype=torch.float16)

    with _cuda_device_of(A):
        fn = getattr(lib, f"ckbit_gemm_fp16_k{k}")
        fn(
            get_ptr(A),
            get_ptr(B_packed),
            get_ptr(B_absmax),
            get_ptr(codebook),
            get_ptr(C),
            ct.c_int(M),
            ct.c_int(K_dim),
            ct.c_int(N),
        )

    return C
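A minimal sketch of the intended call sequence through these two ops, shapes only. The dummy inputs are stand-ins: the k-bit quantize step that would produce packed_flat, absmax_flat, and the codebook is not part of this diff, the flat sizing (k int32 words and one fp32 absmax per 32-weight block) is inferred from the repack math above, and a CUDA build that ships these kernels is assumed.

import torch

M, K_dim, N, k = 16, 4096, 4096, 4
BLOCKSIZE = 32
n_blocks = (K_dim * N) // BLOCKSIZE

# Dummy flat bit-plane data (a real quantize step would fill these).
packed_flat = torch.zeros(n_blocks * k, dtype=torch.int32, device="cuda")
absmax_flat = torch.ones(n_blocks, dtype=torch.float32, device="cuda")

packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
    packed_flat, absmax_flat, K_dim, N, k
)

A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
codebook = torch.randn(2**k, dtype=torch.float32, device="cuda")  # assumed 2**k entries

C = torch.ops.bitsandbytes.kbit_gemm(
    A, packed_tiled, absmax_tiled, codebook, K_dim, N, k
)
assert C.shape == (M, N) and C.dtype == torch.float16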
