Skip to content

Commit 9b155d3

Browse files
TimDettmers and claude committed
Add Stage 4 pipelined GEMM kernel with cp.async double-buffering (89 tests pass)
Double-buffered cp.async pipeline overlapping global→shared memory loads with tensor core computation. B tile and absmax use cp.async (contiguous, always in-bounds from repack). A tile uses synchronous loads (small tile, needs M/K_dim bounds checking).

Key changes from Stage 3:
- 2× shared memory (two pipeline stages)
- B tile stored without +1 column padding (enables contiguous cp.async)
- cp.async.cg.shared.global for 16-byte copies (L2 cache only)
- Prefetch next tile while computing current tile

Output is bit-exact identical to Stage 3 for all K values (2,3,4,5) and all tested matrix sizes, confirming the pipeline is a pure performance change with no math impact.

13 new Stage 4 tests: bit-exact match vs Stage 3 across K values, batch sizes (M=1,4,8,16), and matrix dimensions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ad64c98 commit 9b155d3

File tree

5 files changed

+445
-5
lines changed

5 files changed

+445
-5
lines changed

bitsandbytes/_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,27 @@ def _(
524524
torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
525525
M = A.shape[0]
526526
return torch.empty(M, N, device=A.device, dtype=A.dtype)
527+
528+
529+
# K-bit fused dequant + GEMM (pipelined, Stage 4)
torch.library.define(
    "bitsandbytes::kbit_gemm_pipelined",
    "(Tensor A, Tensor B_packed, Tensor B_absmax, Tensor codebook, int K_dim, int N, int k) -> Tensor",
)


@register_fake("bitsandbytes::kbit_gemm_pipelined")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
) -> torch.Tensor:
    """Fake (meta) impl: validate k and A's shape, return an empty [M, N] output.

    Output inherits A's device and dtype (matches the real CUDA kernel's contract).
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    # new_empty preserves A's device and dtype, same as
    # torch.empty(M, N, device=A.device, dtype=A.dtype).
    return A.new_empty((A.shape[0], N))

bitsandbytes/backends/cuda/ops.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,3 +931,39 @@ def _(
931931
)
932932

933933
return C
934+
935+
936+
@register_kernel("bitsandbytes::kbit_gemm_pipelined", "cuda")
def _(
    A: torch.Tensor,
    B_packed: torch.Tensor,
    B_absmax: torch.Tensor,
    codebook: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
) -> torch.Tensor:
    """CUDA impl of the Stage 4 pipelined k-bit fused dequant + GEMM.

    Dispatches to the compiled ckbit_gemm_pipelined_fp16_k{k} kernel.
    fp16 inputs/outputs; B is bit-plane packed int32 with uint8 (E4M4) absmax.
    Returns a new [M, N] float16 tensor on A's device.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A.dtype == torch.float16, lambda: f"kbit_gemm_pipelined supports float16 only, got {A.dtype}")
    torch._check(B_packed.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed.dtype}")
    torch._check(B_absmax.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax.dtype}")
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    # Mirror the shape check done by the fake/meta registration so bad inputs
    # fail loudly here instead of reading out of bounds inside the kernel.
    torch._check(A.dim() == 2 and A.shape[1] == K_dim, lambda: "A must be [M, K_dim]")
    # Kernel tiles N in chunks of TILE_N=128.
    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")

    M = A.shape[0]
    C = torch.empty(M, N, device=A.device, dtype=torch.float16)

    with _cuda_device_of(A):
        fn = getattr(lib, f"ckbit_gemm_pipelined_fp16_k{k}")
        fn(
            get_ptr(A),
            get_ptr(B_packed),
            get_ptr(B_absmax),
            get_ptr(codebook),
            get_ptr(C),
            ct.c_int(M),
            ct.c_int(K_dim),
            ct.c_int(N),
        )

    return C

csrc/ops.cu

Lines changed: 276 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,278 @@ void kbitGemmMinimal(
11311131
CUDA_CHECK_RETURN(cudaPeekAtLastError());
11321132
}
11331133

1134+
// ---- Stage 4: Pipelined fused kbit dequant + GEMM kernel ----
1135+
// Double-buffered cp.async pipeline overlapping loads with compute.
1136+
// Same math as Stage 3 but with async global→shared memory copies for B and absmax,
1137+
// and synchronous A loads (small tile, needs bounds checking).
1138+
// B tile stored WITHOUT +1 padding (simpler cp.async, bank conflicts deferred to Stage 6).
1139+
1140+
// cp.async helpers (sm_80+)

// Issue one 16-byte asynchronous global->shared copy.
// `.cg` caches at L2 only (skips L1) — both pointers must be 16-byte aligned,
// and `smem` must be a generic pointer into shared memory.
__device__ __forceinline__ void cp_async_cg_16(void* __restrict__ smem, const void* __restrict__ gmem) {
    uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem));
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem_addr), "l"(gmem));
}

// Commit all cp.async ops issued so far by this thread into a single group.
__device__ __forceinline__ void cp_async_fence() {
    asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most N committed cp.async groups are still in flight
// (N = 0 waits for everything).
template <int N>
__device__ __forceinline__ void cp_async_wait() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}
1154+
1155+
// Stage 4 pipelined fused k-bit dequant + GEMM. Requires SM80+ (cp.async).
//
// Launch contract (set up by kbitGemmPipelined below):
//   grid  = (N / TILE_N, ceil(M / TILE_M)); blockIdx.x = N-tile, blockIdx.y = M-tile
//   block = 256 threads = 8 warps; each warp owns TILE_N / 8 = 16 output columns
//   dynamic shared memory = 2 * STAGE_BYTES (double buffer)
//
// Accumulation is fp32 (mma ...f32.f16.f16.f32); output stored as fp16.
template <int K_BITS>
__global__ void kbit_gemm_pipelined(
    const half* __restrict__ A, const unsigned int* __restrict__ B_packed, const unsigned char* __restrict__ B_absmax,
    const float* __restrict__ codebook, half* __restrict__ C, const int M, const int K_dim, const int N
) {
    constexpr int TILE_M = 16;
    constexpr int TILE_K = 64;
    constexpr int TILE_N = 128;
    constexpr int BS = 32; // absmax quantization block size along K
    constexpr int KB_PER_TILE = TILE_K / BS;           // 2
    constexpr int B_COL_WORDS = KB_PER_TILE * K_BITS;  // words per column (no padding)
    constexpr int N_BLOCKS = 2;                        // 16 cols per warp / 8 cols per MMA

    // Per-stage sizes in elements
    constexpr int A_STAGE_ELEMS = TILE_M * TILE_K;     // half elements
    constexpr int B_STAGE_WORDS = TILE_N * B_COL_WORDS; // uint32 elements
    constexpr int ABS_STAGE_BYTES = TILE_N * KB_PER_TILE; // uint8 elements

    // Per-stage sizes in bytes (all naturally 16-byte aligned)
    constexpr int A_STAGE_BYTES = A_STAGE_ELEMS * sizeof(half);
    constexpr int B_STAGE_BYTES = B_STAGE_WORDS * sizeof(unsigned int);
    // Round absmax up to 16-byte boundary for alignment
    constexpr int ABS_STAGE_BYTES_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;

    constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES + ABS_STAGE_BYTES_ALIGNED;

    const int n_tile = blockIdx.x;
    const int m_tile = blockIdx.y;
    const int n_tiles = N / TILE_N;
    const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
    const int warp_id = threadIdx.x / 32;
    const int lane_id = threadIdx.x % 32;
    // gid/tid: lane's (row-group, column-group) position in the m16n8k16 fragment layout.
    const int gid = lane_id / 4;
    const int tid = lane_id % 4;

    const int warp_n_base = warp_id * (TILE_N / 8);
    const int m_base = m_tile * TILE_M;

    // Double-buffered shared memory: 2 stages
    extern __shared__ char smem[];

    // Helper lambdas for stage-indexed shared memory pointers
    // (layout per stage: [A tile][B tile][absmax, 16B-aligned]).
    auto sh_a = [&](int stage) -> half* {
        return reinterpret_cast<half*>(smem + stage * STAGE_BYTES);
    };
    auto sh_b = [&](int stage) -> unsigned int* {
        return reinterpret_cast<unsigned int*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES);
    };
    auto sh_abs = [&](int stage) -> unsigned char* {
        return reinterpret_cast<unsigned char*>(smem + stage * STAGE_BYTES + A_STAGE_BYTES + B_STAGE_BYTES);
    };

    // Codebook in register: lane i holds codebook[i]; values fetched later via __shfl_sync.
    half cb_h = (lane_id < (1 << K_BITS)) ? __float2half(codebook[lane_id]) : __float2half(0.0f);

    // Accumulators (fp32, one 4-element MMA C-fragment per N block)
    float frag_c[N_BLOCKS][4];
#pragma unroll
    for (int nb = 0; nb < N_BLOCKS; nb++)
        frag_c[nb][0] = frag_c[nb][1] = frag_c[nb][2] = frag_c[nb][3] = 0.0f;

    // ---- Tile fetch function (inlined via lambda) ----
    // B and absmax: cp.async (contiguous, always in-bounds from repack)
    // A: synchronous with bounds checking
    auto fetch_tile = [&](int stage, int kt) {
        const int k_base = kt * TILE_K;
        const int tile_idx = kt * n_tiles + n_tile;

        // B tile: contiguous cp.async (16-byte / int4 granularity)
        const int b_global_base = tile_idx * B_STAGE_WORDS;
        constexpr int B_INT4S = B_STAGE_BYTES / 16;
        const int4* b_src = reinterpret_cast<const int4*>(B_packed + b_global_base);
        int4* b_dst = reinterpret_cast<int4*>(sh_b(stage));
        for (int i = threadIdx.x; i < B_INT4S; i += blockDim.x)
            cp_async_cg_16(&b_dst[i], &b_src[i]);

        // Absmax tile: contiguous cp.async
        const int abs_global_base = tile_idx * ABS_STAGE_BYTES;
        constexpr int ABS_INT4S = (ABS_STAGE_BYTES + 15) / 16;
        const int4* abs_src = reinterpret_cast<const int4*>(B_absmax + abs_global_base);
        int4* abs_dst = reinterpret_cast<int4*>(sh_abs(stage));
        if (threadIdx.x < ABS_INT4S)
            cp_async_cg_16(&abs_dst[threadIdx.x], &abs_src[threadIdx.x]);

        // A tile: synchronous with bounds checking (zero-pad outside M x K_dim)
        half* a_dst = sh_a(stage);
        for (int i = threadIdx.x; i < A_STAGE_ELEMS; i += blockDim.x) {
            int row = i / TILE_K;
            int col = i % TILE_K;
            int gr = m_base + row;
            int gc = k_base + col;
            a_dst[row * TILE_K + col] = (gr < M && gc < K_dim) ? A[gr * K_dim + gc] : __float2half(0.0f);
        }
    };

    // ---- Compute function for one k-tile ----
    // Four m16n8k16 MMA steps cover TILE_K = 64 (ks = k-step; two steps per 32-row
    // absmax block, half_idx selects the low/high 16 bits of each plane word).
    auto compute_tile = [&](int stage) {
        half* a_ptr = sh_a(stage);
        unsigned int* b_ptr = sh_b(stage);
        unsigned char* abs_ptr = sh_abs(stage);

#pragma unroll
        for (int ks = 0; ks < 4; ks++) {
            const int k_block = ks / 2;
            const int half_idx = ks % 2;

            // Load A fragment (same as Stage 3)
            uint32_t frag_a[4];
            {
                const int kc0 = ks * 16 + tid * 2;
                const int kc1 = ks * 16 + tid * 2 + 8;
                const int r0 = gid;
                const int r1 = gid + 8;
                half2 h_rlo_klo = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rhi_klo = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc0 + 1] : __float2half(0.0f));
                half2 h_rlo_khi = __halves2half2(
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1] : __float2half(0.0f),
                    (r0 < TILE_M) ? a_ptr[r0 * TILE_K + kc1 + 1] : __float2half(0.0f));
                half2 h_rhi_khi = __halves2half2(
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1] : __float2half(0.0f),
                    (r1 < TILE_M) ? a_ptr[r1 * TILE_K + kc1 + 1] : __float2half(0.0f));
                frag_a[0] = *reinterpret_cast<uint32_t*>(&h_rlo_klo);
                frag_a[1] = *reinterpret_cast<uint32_t*>(&h_rhi_klo);
                frag_a[2] = *reinterpret_cast<uint32_t*>(&h_rlo_khi);
                frag_a[3] = *reinterpret_cast<uint32_t*>(&h_rhi_khi);
            }

#pragma unroll
            for (int nb = 0; nb < N_BLOCKS; nb++) {
                int col = warp_n_base + nb * 8 + gid;

                // B: read from non-padded layout
                unsigned int planes[K_BITS];
                int b_addr = col * B_COL_WORDS + k_block * K_BITS;
#pragma unroll
                for (int b = 0; b < K_BITS; b++)
                    planes[b] = b_ptr[b_addr + b];

                half scale = __float2half(decode_e4m4_absmax(abs_ptr[col * KB_PER_TILE + k_block]));

                const int bit_offset = half_idx * 16;
                const int rows[4] = {2 * tid, 2 * tid + 1, 2 * tid + 8, 2 * tid + 9};
                half vals[4];
#pragma unroll
                for (int r = 0; r < 4; r++) {
                    // Reassemble the K_BITS-bit code from the bit planes, then look the
                    // codebook value up from the lane that holds it and apply the scale.
                    int bit_pos = bit_offset + rows[r];
                    int idx = 0;
#pragma unroll
                    for (int b = 0; b < K_BITS; b++)
                        idx |= ((planes[b] >> bit_pos) & 1) << b;
                    vals[r] = __hmul(__shfl_sync(0xFFFFFFFF, cb_h, idx), scale);
                }

                uint32_t frag_b[2];
                {
                    half2 b0 = __halves2half2(vals[0], vals[1]);
                    half2 b1 = __halves2half2(vals[2], vals[3]);
                    frag_b[0] = *reinterpret_cast<uint32_t*>(&b0);
                    frag_b[1] = *reinterpret_cast<uint32_t*>(&b1);
                }

                asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                             "{%0, %1, %2, %3}, "
                             "{%4, %5, %6, %7}, "
                             "{%8, %9}, "
                             "{%10, %11, %12, %13};\n"
                             : "=f"(frag_c[nb][0]), "=f"(frag_c[nb][1]), "=f"(frag_c[nb][2]),
                               "=f"(frag_c[nb][3])
                             : "r"(frag_a[0]), "r"(frag_a[1]), "r"(frag_a[2]), "r"(frag_a[3]),
                               "r"(frag_b[0]), "r"(frag_b[1]),
                               "f"(frag_c[nb][0]), "f"(frag_c[nb][1]), "f"(frag_c[nb][2]),
                               "f"(frag_c[nb][3]));
            }
        }
    };

    // ---- Double-buffered pipeline ----
    // Fetch first tile
    fetch_tile(0, 0);
    cp_async_fence();

    for (int kt = 0; kt < k_tiles; kt++) {
        int cur = kt % 2;

        // Prefetch next tile into the other buffer
        if (kt + 1 < k_tiles) {
            fetch_tile((kt + 1) % 2, kt + 1);
            cp_async_fence();
            cp_async_wait<1>(); // wait for current tile, allow next pending
        } else {
            cp_async_wait<0>(); // last tile: wait for everything
        }
        // Barrier makes both the cp.async data and the synchronous A-tile
        // writes visible to the whole block before compute reads them.
        __syncthreads();

        // Compute on current tile
        compute_tile(cur);
        // Don't let the next iteration's fetch overwrite this buffer early.
        __syncthreads();
    }

    // ---- Write output (same as Stage 3) ----
#pragma unroll
    for (int nb = 0; nb < N_BLOCKS; nb++) {
        int c_col = n_tile * TILE_N + warp_n_base + nb * 8 + tid * 2;
        int m_row0 = m_base + gid;
        int m_row1 = m_base + gid + 8;
        if (m_row0 < M) {
            C[m_row0 * N + c_col] = __float2half(frag_c[nb][0]);
            C[m_row0 * N + c_col + 1] = __float2half(frag_c[nb][1]);
        }
        if (m_row1 < M) {
            C[m_row1 * N + c_col] = __float2half(frag_c[nb][2]);
            C[m_row1 * N + c_col + 1] = __float2half(frag_c[nb][3]);
        }
    }
}
1374+
1375+
// Stage 4 GEMM launcher
// Host wrapper: derives grid/block shape and dynamic shared-memory size from
// the same tile constants the kernel uses, then launches kbit_gemm_pipelined<K>.
template <int K>
void kbitGemmPipelined(
    const half* A, const unsigned int* B_packed, const unsigned char* B_absmax, const float* codebook, half* C, int M,
    int K_dim, int N
) {
    // Tile geometry — must stay in sync with kbit_gemm_pipelined<K>.
    constexpr int kTileM = 16;
    constexpr int kTileK = 64;
    constexpr int kTileN = 128;
    constexpr int kBlockSize = 32; // absmax quantization block size
    constexpr int kAbsPerCol = kTileK / kBlockSize;

    // Bytes per pipeline stage: A tile (half) + B tile (K bit-plane words per
    // absmax block per column) + absmax bytes rounded up to 16-byte alignment.
    constexpr int kATileBytes = kTileM * kTileK * (int)sizeof(half);
    constexpr int kBTileBytes = kTileN * kAbsPerCol * K * (int)sizeof(unsigned int);
    constexpr int kAbsTileBytes = (kTileN * kAbsPerCol + 15) & ~15;
    constexpr int kStageBytes = kATileBytes + kBTileBytes + kAbsTileBytes;

    // One block per (N-tile, M-tile); 256 threads (8 warps) per block.
    dim3 grid(N / kTileN, (M + kTileM - 1) / kTileM);
    dim3 block(256);

    // Two stages of shared memory for the double-buffered pipeline.
    int smem_size = 2 * kStageBytes;

    kbit_gemm_pipelined<K><<<grid, block, smem_size>>>(A, B_packed, B_absmax, codebook, C, M, K_dim, N);
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
1405+
11341406
// ---- Debug: Simple MMA test kernel ----
11351407
// Takes fp16 A[16,16] and fp16 B[16,8] (B stored row-major), outputs fp32 C[16,8].
11361408
__global__ void test_mma_kernel(const half* __restrict__ A, const half* __restrict__ B, float* __restrict__ C) {
@@ -1249,8 +1521,10 @@ INSTANTIATE_KBIT_REPACK(3)
12491521
INSTANTIATE_KBIT_REPACK(4)
12501522
INSTANTIATE_KBIT_REPACK(5)
12511523

1252-
// GEMM instantiations: one per K value (fp16 only for Stage 3)
1253-
#define INSTANTIATE_KBIT_GEMM(K) template void kbitGemmMinimal<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int);
1524+
// GEMM instantiations: one per K value (fp16 only)
// Explicitly instantiates both the Stage 3 (minimal) and Stage 4 (pipelined)
// launchers so callers in other translation units can link against them.
#define INSTANTIATE_KBIT_GEMM(K) \
    template void kbitGemmMinimal<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int); \
    template void kbitGemmPipelined<K>(const half*, const unsigned int*, const unsigned char*, const float*, half*, int, int, int);

INSTANTIATE_KBIT_GEMM(2)
INSTANTIATE_KBIT_GEMM(3)

0 commit comments

Comments
 (0)