Skip to content

Commit 24406d2

Browse files
TimDettmers and Claude
committed
Add Stage 6 production kernel with bf16 support (139 tests pass)
New production kernel (kbit_gemm_prod) templates on scalar_t to support both fp16 and bf16 activation/output types. Uses the same split-K architecture as Stage 5 with type-dispatched MMA instructions: - fp16: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 - bf16: mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 Helper structs (ScalarOps, pack_two, mma_m16n8k16) abstract type-specific operations. 8 kernel variants instantiated (4 K values x 2 dtypes). fp16 path matches Stage 5 split-K output bit-for-bit. bf16 path matches Python reference within tolerance for all K values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fdcec9c commit 24406d2

File tree

5 files changed

+611
-0
lines changed

5 files changed

+611
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,3 +573,29 @@ def _(
573573
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
574574
M = A.shape[0]
575575
return torch.empty(M, N, device=A.device, dtype=A.dtype)
576+
577+
578+
# K-bit fused dequant + GEMM (production, Stage 6: fp16 + bf16)
torch.library.define(
    "bitsandbytes::kbit_gemm_prod",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int k, int k_chunks) -> Tensor",
)


@register_fake("bitsandbytes::kbit_gemm_prod")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    k_chunks: int,
) -> torch.Tensor:
    """Fake (meta) implementation used for shape inference and tracing.

    Validates the bit width, A's layout, and A's dtype, then returns an
    uninitialized [M, N] tensor matching A's device and dtype.
    """
    torch._check(2 <= k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
    return torch.empty(A.shape[0], N, device=A.device, dtype=A.dtype)

bitsandbytes/backends/cuda/ops.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,3 +1021,61 @@ def _(
10211021
)
10221022

10231023
return C
1024+
1025+
1026+
@register_kernel("bitsandbytes::kbit_gemm_prod", "cuda")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    k_chunks: int,
) -> torch.Tensor:
    """CUDA implementation of the Stage 6 production k-bit fused dequant + GEMM.

    Dispatches to one of the compiled kernel variants selected by A's dtype
    (fp16/bf16) and the bit width ``k`` (2-5), computing ``C = A @ dequant(B)``
    over 16x128 output tiles. When ``k_chunks > 1`` a split-K path is used:
    partial products go into an fp32 workspace and per-tile counters are
    allocated (presumably so the kernel can detect the last finished chunk
    per tile — confirm against the CUDA source).

    Args:
        A: [M, K_dim] activation matrix, float16 or bfloat16.
        B_packed: int32 tensor holding the k-bit packed weights.
        B_absmax: uint8 (E4M4) per-block absmax scales.
        codebook: float32 dequantization codebook.
        K_dim: inner (reduction) dimension.
        N: number of output columns; must be a multiple of the 128-wide tile.
        k: quantization bit width, 2-5.
        k_chunks: number of split-K chunks, >= 1.

    Returns:
        C: [M, N] tensor with A's dtype on A's device.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    # Consistency with the fake (meta) registration: validate A's shape here
    # too, so eager CUDA calls fail with the same clear message instead of
    # handing a mismatched K_dim to the raw kernel.
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    torch._check(
        A.dtype in (torch.float16, torch.bfloat16),
        lambda: f"kbit_gemm_prod supports float16 and bfloat16, got {A.dtype}",
    )
    torch._check(B_packed.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed.dtype}")
    torch._check(B_absmax.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax.dtype}")
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
    torch._check(k_chunks >= 1, lambda: f"k_chunks must be >= 1, got {k_chunks}")

    M = A.shape[0]
    C = torch.empty(M, N, device=A.device, dtype=A.dtype)

    # Output tile geometry; must match the TILE_M/TILE_N baked into the kernel.
    TILE_M = 16
    TILE_N = 128
    m_tiles = (M + TILE_M - 1) // TILE_M
    n_tiles = N // TILE_N

    if k_chunks > 1:
        # Split-K: fp32 accumulation workspace and one counter per output
        # tile. Both must start zeroed for the reduction to be correct.
        C_workspace = torch.zeros(M, N, device=A.device, dtype=torch.float32)
        tile_counters = torch.zeros(m_tiles * n_tiles, device=A.device, dtype=torch.int32)
    else:
        # Single-chunk path: the kernel writes C directly; pass empty
        # placeholders so the launcher still receives valid pointers.
        C_workspace = torch.empty(0, device=A.device, dtype=torch.float32)
        tile_counters = torch.empty(0, device=A.device, dtype=torch.int32)

    dtype_suffix = "fp16" if A.dtype == torch.float16 else "bf16"

    with _cuda_device_of(A):
        # One exported symbol per (dtype, k) pair, e.g. ckbit_gemm_prod_fp16_k4.
        fn = getattr(lib, f"ckbit_gemm_prod_{dtype_suffix}_k{k}")
        fn(
            get_ptr(A),
            get_ptr(B_packed),
            get_ptr(B_absmax),
            get_ptr(codebook),
            get_ptr(C),
            get_ptr(C_workspace),
            get_ptr(tile_counters),
            ct.c_int(M),
            ct.c_int(K_dim),
            ct.c_int(N),
            ct.c_int(k_chunks),
        )

    return C

0 commit comments

Comments
 (0)