Skip to content

Commit fdcec9c

Browse files
TimDettmersclaude
andcommitted
Add Stage 5 split-K GEMM kernel (110 tests pass)
Split-K support allows multiple thread blocks to share an output tile, each handling a subset of k-tiles. Partial sums accumulated via atomicAdd in fp32 workspace, with the last contributor converting fp32->fp16. Grid is 2D (n_tiles, m_tiles) for k_chunks=1 (same as Stage 4) and 3D (n_tiles, m_tiles, k_chunks) for k_chunks>1. k_chunks=1 produces bit-exact output matching Stage 4. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9b155d3 commit fdcec9c

File tree

5 files changed

+500
-1
lines changed

5 files changed

+500
-1
lines changed

bitsandbytes/_ops.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,28 @@ def _(
548548
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
549549
M = A.shape[0]
550550
return torch.empty(M, N, device=A.device, dtype=A.dtype)
551+
552+
553+
# K-bit fused dequant + GEMM (split-K, Stage 5)
554+
555+
torch.library.define(
556+
"bitsandbytes::kbit_gemm_splitk",
557+
"(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int k, int k_chunks) -> Tensor",
558+
)
559+
560+
561+
@register_fake("bitsandbytes::kbit_gemm_splitk")
562+
def _(
563+
A: torch.Tensor,
564+
B_packed: torch.Tensor,
565+
B_absmax: torch.Tensor,
566+
codebook: torch.Tensor,
567+
K_dim: int,
568+
N: int,
569+
k: int,
570+
k_chunks: int,
571+
) -> torch.Tensor:
572+
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
573+
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
574+
M = A.shape[0]
575+
return torch.empty(M, N, device=A.device, dtype=A.dtype)

bitsandbytes/backends/cuda/ops.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,57 @@ def _(
967967
)
968968

969969
return C
970+
971+
972+
@register_kernel("bitsandbytes::kbit_gemm_splitk", "cuda")
973+
def _(
974+
A: torch.Tensor,
975+
B_packed: torch.Tensor,
976+
B_absmax: torch.Tensor,
977+
codebook: torch.Tensor,
978+
K_dim: int,
979+
N: int,
980+
k: int,
981+
k_chunks: int,
982+
) -> torch.Tensor:
983+
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
984+
torch._check(A.dtype == torch.float16, lambda: f"kbit_gemm_splitk supports float16 only, got {A.dtype}")
985+
torch._check(B_packed.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed.dtype}")
986+
torch._check(B_absmax.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax.dtype}")
987+
torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
988+
torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
989+
torch._check(k_chunks >= 1, lambda: f"k_chunks must be >= 1, got {k_chunks}")
990+
991+
M = A.shape[0]
992+
C = torch.empty(M, N, device=A.device, dtype=torch.float16)
993+
994+
TILE_M = 16
995+
TILE_N = 128
996+
m_tiles = (M + TILE_M - 1) // TILE_M
997+
n_tiles = N // TILE_N
998+
999+
# Allocate workspace and tile counters for split-K (k_chunks > 1)
1000+
if k_chunks > 1:
1001+
C_workspace = torch.zeros(M, N, device=A.device, dtype=torch.float32)
1002+
tile_counters = torch.zeros(m_tiles * n_tiles, device=A.device, dtype=torch.int32)
1003+
else:
1004+
C_workspace = torch.empty(0, device=A.device, dtype=torch.float32)
1005+
tile_counters = torch.empty(0, device=A.device, dtype=torch.int32)
1006+
1007+
with _cuda_device_of(A):
1008+
fn = getattr(lib, f"ckbit_gemm_splitk_fp16_k{k}")
1009+
fn(
1010+
get_ptr(A),
1011+
get_ptr(B_packed),
1012+
get_ptr(B_absmax),
1013+
get_ptr(codebook),
1014+
get_ptr(C),
1015+
get_ptr(C_workspace),
1016+
get_ptr(tile_counters),
1017+
ct.c_int(M),
1018+
ct.c_int(K_dim),
1019+
ct.c_int(N),
1020+
ct.c_int(k_chunks),
1021+
)
1022+
1023+
return C

csrc/ops.cu

Lines changed: 293 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1403,6 +1403,297 @@ void kbitGemmPipelined(
14031403
CUDA_CHECK_RETURN(cudaPeekAtLastError());
14041404
}
14051405

1406+
// ---- Stage 5: Split-K fused kbit dequant + GEMM kernel ----
1407+
// Extends Stage 4 with split-K: multiple blocks share an output tile, each handling
1408+
// a subset of k-tiles. Partial sums accumulated via atomicAdd in fp32 workspace.
1409+
// Grid: (n_tiles, m_tiles) for k_chunks=1, (n_tiles, m_tiles, k_chunks) for k_chunks>1.
1410+
1411+
template <int K_BITS>
1412+
__global__ void kbit_gemm_splitk(
1413+
const half* __restrict__ A, const unsigned int* __restrict__ B_packed, const unsigned char* __restrict__ B_absmax,
1414+
const float* __restrict__ codebook, half* __restrict__ C, float* __restrict__ C_workspace,
1415+
int* __restrict__ tile_counters, const int M, const int K_dim, const int N, const int k_chunks
1416+
) {
1417+
constexpr int TILE_M = 16;
1418+
constexpr int TILE_K = 64;
1419+
constexpr int TILE_N = 128;
1420+
constexpr int BS = 32;
1421+
constexpr int KB_PER_TILE = TILE_K / BS;
1422+
constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;
1423+
constexpr int N_BLOCKS = 2;
1424+
1425+
constexpr int A_STAGE_ELEMS = TILE_M * TILE_K;
1426+
constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS;
1427+
constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE;
1428+
1429+
constexpr int A_STAGE_BYTES = A_STAGE_ELEMS * sizeof(half);
1430+
constexpr int B_STAGE_BYTES_VAL = B_STAGE_WORDS * sizeof(unsigned int);
1431+
constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
1432+
constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;
1433+
1434+
const int n_tile = blockIdx.x;
1435+
const int m_tile = blockIdx.y;
1436+
const int k_chunk_id = (k_chunks > 1) ? blockIdx.z : 0;
1437+
const int n_tiles = N / TILE_N;
1438+
const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
1439+
const int tiles_per_chunk = (k_tiles + k_chunks - 1) / k_chunks;
1440+
const int kt_start = k_chunk_id * tiles_per_chunk;
1441+
const int kt_end = min(kt_start + tiles_per_chunk, k_tiles);
1442+
1443+
const int warp_id = threadIdx.x / 32;
1444+
const int lane_id = threadIdx.x % 32;
1445+
const int gid = lane_id / 4;
1446+
const int tid = lane_id % 4;
1447+
const int warp_n_base = warp_id * (TILE_N / 8);
1448+
const int m_base = m_tile * TILE_M;
1449+
1450+
// Double-buffered shared memory
1451+
extern __shared__ char smem[];
1452+
auto sh_a = [&](int stage) -> half* {
1453+
return reinterpret_cast<half*>(smem + stage * STAGE_BYTES);
1454+
};
1455+
auto sh_b = [&](int stage) -> unsigned int* {
1456+
return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES);
1457+
};
1458+
auto sh_abs = [&](int stage) -> unsigned char* {
1459+
return reinterpret_cast<unsigned char*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES + B_STAGE_BYTES_VAL);
1460+
};
1461+
1462+
half cb_h = (lane_id < (1 << K_BITS)) ? __float2half(codebook[lane_id]) : __float2half(0.0f);
1463+
1464+
float frag_c[N_BLOCKS][4];
1465+
#pragma unroll
1466+
for (int nb = 0; nb < N_BLOCKS; nb++)
1467+
frag_c[nb][0] = frag_c[nb][1] = frag_c[nb][2] = frag_c[nb][3] = 0.0f;
1468+
1469+
// Early exit if this chunk has no tiles
1470+
if (kt_start >= k_tiles)
1471+
return;
1472+
1473+
// Fetch tile lambda (same as Stage 4)
1474+
auto fetch_tile = [&](int stage, int kt) {
1475+
const int k_base = kt * TILE_K;
1476+
const int tile_idx = kt * n_tiles + n_tile;
1477+
1478+
const int b_global_base = tile_idx * B_STAGE_WORDS;
1479+
constexpr int B_INT4S = B_STAGE_BYTES_VAL / 16;
1480+
const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
1481+
int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
1482+
for (int i = threadIdx.x; i < B_INT4S; i += blockDim.x)
1483+
cp_async_cg_16(&b_dst[i], &b_src[i]);
1484+
1485+
const int abs_global_base = tile_idx * ABS_STAGE_BYTES;
1486+
constexpr int ABS_INT4S = (ABS_STAGE_BYTES + 15) / 16;
1487+
const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
1488+
int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
1489+
if (threadIdx.x < ABS_INT4S)
1490+
cp_async_cg_16(&abs_dst[threadIdx.x], &abs_src[threadIdx.x]);
1491+
1492+
half* a_dst = sh_a(stage);
1493+
for (int i = threadIdx.x; i < A_STAGE_ELEMS; i += blockDim.x) {
1494+
int row = i / TILE_K;
1495+
int col = i % TILE_K;
1496+
int gr = m_base + row;
1497+
int gc = k_base + col;
1498+
a_dst[row * TILE_K + col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : __float2half(0.0f);
1499+
}
1500+
};
1501+
1502+
// Compute tile lambda (same as Stage 4)
1503+
auto compute_tile = [&](int stage) {
1504+
half* a_ptr = sh_a(stage);
1505+
unsigned int* b_ptr = sh_b(stage);
1506+
unsigned char* abs_ptr = sh_abs(stage);
1507+
1508+
#pragma unroll
1509+
for (int ks = 0; ks < 4; ks++) {
1510+
const int k_block = ks / 2;
1511+
const int half_idx = ks % 2;
1512+
1513+
uint32_t frag_a[4];
1514+
{
1515+
const int kc0 = ks * 16 + tid * 2;
1516+
const int kc1 = ks * 16 + tid * 2 + 8;
1517+
const int r0 = gid;
1518+
const int r1 = gid + 8;
1519+
half2 h_rlo_klo = __halves2half2(
1520+
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0] : __float2half(0.0f),
1521+
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0 + 1] : __float2half(0.0f));
1522+
half2 h_rhi_klo = __halves2half2(
1523+
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0] : __float2half(0.0f),
1524+
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0 + 1] : __float2half(0.0f));
1525+
half2 h_rlo_khi = __halves2half2(
1526+
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1] : __float2half(0.0f),
1527+
(r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1 + 1] : __float2half(0.0f));
1528+
half2 h_rhi_khi = __halves2half2(
1529+
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1] : __float2half(0.0f),
1530+
(r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1 + 1] : __float2half(0.0f));
1531+
frag_a[0] = *reinterpret_cast<uint32_t*>(&h_rlo_klo);
1532+
frag_a[1] = *reinterpret_cast<uint32_t*>(&h_rhi_klo);
1533+
frag_a[2] = *reinterpret_cast<uint32_t*>(&h_rlo_khi);
1534+
frag_a[3] = *reinterpret_cast<uint32_t*>(&h_rhi_khi);
1535+
}
1536+
1537+
#pragma unroll
1538+
for (int nb = 0; nb < N_BLOCKS; nb++) {
1539+
int col = warp_n_base + nb * 8 + gid;
1540+
unsigned int planes[K_BITS];
1541+
int b_addr = col * B_COL_WORDS + k_block * K_BITS;
1542+
#pragma unroll
1543+
for (int b = 0; b < K_BITS; b++)
1544+
planes[b] = b_ptr[b_addr + b];
1545+
1546+
half scale = __float2half(decode_e4m4_absmax(abs_ptr[col * KB_PER_TILE + k_block]));
1547+
1548+
const int bit_offset = half_idx * 16;
1549+
const int rows[4] = {2 * tid, 2 * tid + 1, 2 * tid + 8, 2 * tid + 9};
1550+
half vals[4];
1551+
#pragma unroll
1552+
for (int r = 0; r < 4; r++) {
1553+
int bit_pos = bit_offset + rows[r];
1554+
int idx = 0;
1555+
#pragma unroll
1556+
for (int b = 0; b < K_BITS; b++)
1557+
idx |= ((planes[b] >> bit_pos) & 1) << b;
1558+
vals[r] = __hmul(__shfl_sync(0xFFFFFFFF, cb_h, idx), scale);
1559+
}
1560+
1561+
uint32_t frag_b[2];
1562+
{
1563+
half2 b0 = __halves2half2(vals[0], vals[1]);
1564+
half2 b1 = __halves2half2(vals[2], vals[3]);
1565+
frag_b[0] = *reinterpret_cast<uint32_t*>(&b0);
1566+
frag_b[1] = *reinterpret_cast<uint32_t*>(&b1);
1567+
}
1568+
1569+
asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
1570+
"{%0, %1, %2, %3}, "
1571+
"{%4, %5, %6, %7}, "
1572+
"{%8, %9}, "
1573+
"{%10, %11, %12, %13};\n"
1574+
: "=f"(frag_c[nb][0]), "=f"(frag_c[nb][1]), "=f"(frag_c[nb][2]),
1575+
"=f"(frag_c[nb][3])
1576+
: "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]),
1577+
"r"(frag_b[0]), "r"(frag_b[1]),
1578+
"f"(frag_c[nb][0]), "f"(frag_c[nb][1]), "f"(frag_c[nb][2]),
1579+
"f"(frag_c[nb][3]));
1580+
}
1581+
}
1582+
};
1583+
1584+
// ---- Pipeline over [kt_start, kt_end) ----
1585+
fetch_tile(0, kt_start);
1586+
cp_async_fence();
1587+
1588+
for (int kt = kt_start; kt < kt_end; kt++) {
1589+
int cur = (kt - kt_start) % 2;
1590+
if (kt + 1 < kt_end) {
1591+
fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
1592+
cp_async_fence();
1593+
cp_async_wait<1>();
1594+
} else {
1595+
cp_async_wait<0>();
1596+
}
1597+
__syncthreads();
1598+
compute_tile(cur);
1599+
__syncthreads();
1600+
}
1601+
1602+
// ---- Write output ----
1603+
if (k_chunks == 1) {
1604+
// No split-K: write fp16 directly (same as Stage 4)
1605+
#pragma unroll
1606+
for (int nb = 0; nb < N_BLOCKS; nb++) {
1607+
int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
1608+
int m_row0 = m_base + gid;
1609+
int m_row1 = m_base + gid + 8;
1610+
if (m_row0 < M) {
1611+
C[m_row0 * N + c_col] = __float2half(frag_c[nb][0]);
1612+
C[m_row0 * N + c_col + 1] = __float2half(frag_c[nb][1]);
1613+
}
1614+
if (m_row1 < M) {
1615+
C[m_row1 * N + c_col] = __float2half(frag_c[nb][2]);
1616+
C[m_row1 * N + c_col + 1] = __float2half(frag_c[nb][3]);
1617+
}
1618+
}
1619+
} else {
1620+
// Split-K: atomicAdd partial sums to fp32 workspace (pre-zeroed by host)
1621+
#pragma unroll
1622+
for (int nb = 0; nb < N_BLOCKS; nb++) {
1623+
int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
1624+
int m_row0 = m_base + gid;
1625+
int m_row1 = m_base + gid + 8;
1626+
if (m_row0 < M) {
1627+
atomicAdd(&C_workspace[m_row0 * N + c_col], frag_c[nb][0]);
1628+
atomicAdd(&C_workspace[m_row0 * N + c_col + 1], frag_c[nb][1]);
1629+
}
1630+
if (m_row1 < M) {
1631+
atomicAdd(&C_workspace[m_row1 * N + c_col], frag_c[nb][2]);
1632+
atomicAdd(&C_workspace[m_row1 * N + c_col + 1], frag_c[nb][3]);
1633+
}
1634+
}
1635+
1636+
// Ensure all atomicAdds from this block are globally visible
1637+
__threadfence();
1638+
1639+
// Signal completion and check if we're the last contributor
1640+
__shared__ int is_last;
1641+
if (threadIdx.x == 0) {
1642+
int mn_id = m_tile * n_tiles + n_tile;
1643+
int done = atomicAdd(&tile_counters[mn_id], 1);
1644+
is_last = (done == k_chunks - 1) ? 1 : 0;
1645+
}
1646+
__syncthreads();
1647+
1648+
// Last contributor: convert fp32 workspace -> fp16 output for this tile
1649+
if (is_last) {
1650+
for (int i = threadIdx.x; i < TILE_M * TILE_N; i += blockDim.x) {
1651+
int row = m_base + i / TILE_N;
1652+
int col = n_tile * TILE_N + i % TILE_N;
1653+
if (row < M)
1654+
C[row * N + col] = __float2half(C_workspace[row * N + col]);
1655+
}
1656+
}
1657+
}
1658+
}
1659+
1660+
// Stage 5 split-K GEMM launcher
1661+
template <int K>
1662+
void kbitGemmSplitK(
1663+
const half* A, const unsigned int* B_packed, const unsigned char* B_absmax, const float* codebook, half* C,
1664+
float* C_workspace, int* tile_counters, int M, int K_dim, int N, int k_chunks
1665+
) {
1666+
constexpr int TILE_M = 16;
1667+
constexpr int TILE_K = 64;
1668+
constexpr int TILE_N = 128;
1669+
constexpr int BS = 32;
1670+
constexpr int KB_PER_TILE = TILE_K / BS;
1671+
constexpr int B_COL_WORDS = KB_PER_TILE * K;
1672+
1673+
constexpr int A_STAGE_BYTES = TILE_M * TILE_K * sizeof(half);
1674+
constexpr int B_STAGE_BYTES = TILE_N * B_COL_WORDS * sizeof(unsigned int);
1675+
constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE;
1676+
constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
1677+
constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES + ABS_STAGE_ALIGNED;
1678+
1679+
int m_tiles = (M + TILE_M - 1) / TILE_M;
1680+
int n_tiles = N / TILE_N;
1681+
1682+
dim3 block(256);
1683+
int smem_size = 2 * STAGE_BYTES;
1684+
1685+
if (k_chunks <= 1) {
1686+
dim3 grid(n_tiles, m_tiles);
1687+
kbit_gemm_splitk<K><<<grid, block, smem_size>>>(
1688+
A, B_packed, B_absmax, codebook, C, nullptr, nullptr, M, K_dim, N, 1);
1689+
} else {
1690+
dim3 grid(n_tiles, m_tiles, k_chunks);
1691+
kbit_gemm_splitk<K><<<grid, block, smem_size>>>(
1692+
A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, M, K_dim, N, k_chunks);
1693+
}
1694+
CUDA_CHECK_RETURN(cudaPeekAtLastError());
1695+
}
1696+
14061697
// ---- Debug: Simple MMA test kernel ----
14071698
// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
14081699
__global__ void test_mma_kernel(const half* __restrict__ A, const half* __restrict__ B, float* __restrict__ C) {
@@ -1524,7 +1815,8 @@ INSTANTIATE_KBIT_REPACK(5)
15241815
// GEMM instantiations: one per K value (fp16 only)
15251816
#define INSTANTIATE_KBIT_GEMM(K) \
15261817
template void kbitGemmMinimal<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int); \
1527-
template void kbitGemmPipelined<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int);
1818+
template void kbitGemmPipelined<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int); \
1819+
template void kbitGemmSplitK<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, float*, int*, int, int, int, int);
15281820

15291821
INSTANTIATE_KBIT_GEMM(2)
15301822
INSTANTIATE_KBIT_GEMM(3)

0 commit comments

Comments
 (0)