Skip to content

Commit 2cc0224

Browse files
TimDettmersclaude
andcommitted
Add 4-stage pipeline depth on datacenter GPUs for MMA kernels
- MMA kernel (kbit_gemm_prod): pipeline depth 2→4 stages on datacenter GPUs via constexpr NUM_STAGES conditional on BNB_DATACENTER_GPU. Consumer GPUs keep 2-stage double-buffered pipeline. - Grouped MMA kernel (kbit_grouped_gemm_prod): same 4-stage pipeline. - Add pipelineNumStages() runtime helper for host-side shared memory allocation (returns 4 on sm_90/sm_100, 2 on consumer). - Add cudaFuncSetAttribute call when 4-stage shmem exceeds 48KB limit (occurs for large M_BLOCKS with k>=4). - Generalize pre-fill loop to fill NUM_STAGES-1 tiles instead of hardcoding 1 tile. - Adjust L2 prefetch to prefetch beyond the pipeline window. H100 effect: neutral (±1% on most configs, within noise). Consumer regression: zero (174/174 tests pass, benchmarks unchanged at 2 stages). Infrastructure enables future tuning with larger K_dim shapes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 71c4874 commit 2cc0224

File tree

1 file changed

+85
-24
lines changed

1 file changed

+85
-24
lines changed

csrc/ops.cu

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,13 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
11251125
constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
11261126
constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;
11271127

1128+
// Pipeline depth: 4 stages on datacenter GPUs (228KB shmem), 2 on consumer (100KB)
1129+
#if BNB_DATACENTER_GPU
1130+
constexpr int NUM_STAGES = 4;
1131+
#else
1132+
constexpr int NUM_STAGES = 2;
1133+
#endif
1134+
11281135
const int n_tiles = N / TILE_N;
11291136
const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
11301137
const int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;
@@ -1138,7 +1145,7 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
11381145
const int tid = lane_id % 4;
11391146
const int warp_n_base = warp_id * COLS_PER_WARP;
11401147

1141-
// Double-buffered shared memory
1148+
// Multi-stage shared memory (NUM_STAGES stages)
11421149
extern __shared__ char smem[];
11431150
auto sh_a = [&](int stage) -> scalar_t* { return reinterpret_cast<scalar_t*>(smem + stage * STAGE_BYTES); };
11441151
auto sh_b = [&](int stage) -> unsigned int* {
@@ -1307,22 +1314,31 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
13071314
}
13081315
};
13091316

1310-
// Pipeline: double-buffered cp.async
1311-
fetch_tile(0, kt_start);
1312-
cp_async_fence();
1317+
// Pipeline: NUM_STAGES-deep cp.async (2 on consumer, 4 on datacenter)
1318+
// Pre-fill first (NUM_STAGES - 1) tiles
1319+
{
1320+
int prefill_end = kt_start + NUM_STAGES - 1;
1321+
if (prefill_end > kt_end)
1322+
prefill_end = kt_end;
1323+
for (int pf = kt_start; pf < prefill_end; pf++) {
1324+
fetch_tile((pf - kt_start) % NUM_STAGES, pf);
1325+
cp_async_fence();
1326+
}
1327+
}
13131328

13141329
for (int kt = kt_start; kt < kt_end; kt++) {
1315-
int cur = (kt - kt_start) % 2;
1316-
if (kt + 1 < kt_end) {
1317-
fetch_tile((kt + 1 - kt_start) % 2, kt + 1);
1330+
int cur = (kt - kt_start) % NUM_STAGES;
1331+
int fetch_kt = kt + NUM_STAGES - 1;
1332+
if (fetch_kt < kt_end) {
1333+
fetch_tile((fetch_kt - kt_start) % NUM_STAGES, fetch_kt);
13181334
cp_async_fence();
1319-
// L2 prefetch for tile kt+2 (warms L2 before next fetch_tile issues cp.async)
1320-
if (kt + 2 < kt_end) {
1321-
const int pf_tile = (kt + 2) * n_tiles + n_tile;
1335+
// L2 prefetch for tile beyond the pipeline
1336+
if (fetch_kt + 1 < kt_end) {
1337+
const int pf_tile = (fetch_kt + 1) * n_tiles + n_tile;
13221338
prefetch_l2(B_packed + pf_tile * B_STAGE_WORDS);
13231339
prefetch_l2(B_absmax + pf_tile * ABS_STAGE_ELEMS);
13241340
}
1325-
cp_async_wait<1>();
1341+
cp_async_wait<NUM_STAGES - 1>();
13261342
} else {
13271343
cp_async_wait<0>();
13281344
}
@@ -1392,6 +1408,19 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
13921408
} // end persistent work loop
13931409
}
13941410

1411+
// Pipeline stage count: 4 on datacenter GPUs (more shmem), 2 on consumer.
1412+
static int pipelineNumStages() {
1413+
static int cached = -1;
1414+
if (cached < 0) {
1415+
int major = 0, minor = 0;
1416+
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
1417+
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
1418+
int sm = major * 10 + minor;
1419+
cached = (sm == 90 || sm == 100) ? 4 : 2;
1420+
}
1421+
return cached;
1422+
}
1423+
13951424
// Production GEMM launcher — persistent kernel with auto k_splits
13961425
template <int K, int MB, int TN = 128, typename scalar_t = half, typename ABSMAX_T = unsigned char>
13971426
static void kbitGemmProdLaunch(
@@ -1437,7 +1466,16 @@ static void kbitGemmProdLaunch(
14371466
int grid_size = (k_splits == 1) ? total_work : min(target_blocks, total_work);
14381467

14391468
dim3 block(BLOCK_DIM);
1440-
int smem_size = 2 * STAGE_BYTES;
1469+
int num_stages = pipelineNumStages();
1470+
int smem_size = num_stages * STAGE_BYTES;
1471+
1472+
// If shared memory exceeds default 48KB limit, increase it
1473+
if (smem_size > 48 * 1024) {
1474+
cudaFuncSetAttribute(
1475+
kbit_gemm_prod<K, MB, TN, scalar_t, ABSMAX_T>,
1476+
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size
1477+
);
1478+
}
14411479

14421480
kbit_gemm_prod<K, MB, TN, scalar_t, ABSMAX_T><<<grid_size, block, smem_size, stream>>>(
14431481
A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, M, K_dim, N, k_splits, total_work
@@ -1538,6 +1576,13 @@ __global__ void kbit_grouped_gemm_prod(
15381576
constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15) & ~15;
15391577
constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;
15401578

1579+
// Pipeline depth: 4 stages on datacenter GPUs, 2 on consumer
1580+
#if BNB_DATACENTER_GPU
1581+
constexpr int NUM_STAGES = 4;
1582+
#else
1583+
constexpr int NUM_STAGES = 2;
1584+
#endif
1585+
15411586
const int n_tiles = N / TILE_N;
15421587
const int k_tiles = (K_dim + TILE_K - 1) / TILE_K;
15431588
const int tiles_per_split = (k_tiles + k_splits - 1) / k_splits;
@@ -1552,7 +1597,7 @@ __global__ void kbit_grouped_gemm_prod(
15521597
const int tid = lane_id % 4;
15531598
const int warp_n_base = warp_id * (TILE_N / NUM_WARPS);
15541599

1555-
// Double-buffered shared memory
1600+
// Multi-stage shared memory (NUM_STAGES stages)
15561601
extern __shared__ char smem[];
15571602
auto sh_a = [&](int stage) -> scalar_t* { return reinterpret_cast<scalar_t*>(smem + stage * STAGE_BYTES); };
15581603
auto sh_b = [&](int stage) -> unsigned int* {
@@ -1750,22 +1795,30 @@ __global__ void kbit_grouped_gemm_prod(
17501795
}
17511796
};
17521797

1753-
// Pipeline: double-buffered cp.async over this split's k-tile range
1754-
fetch_tile(0, kt_start);
1755-
cp_async_fence();
1798+
// Pipeline: NUM_STAGES-deep cp.async (2 on consumer, 4 on datacenter)
1799+
{
1800+
int prefill_end = kt_start + NUM_STAGES - 1;
1801+
if (prefill_end > kt_end)
1802+
prefill_end = kt_end;
1803+
for (int pf = kt_start; pf < prefill_end; pf++) {
1804+
fetch_tile((pf - kt_start) % NUM_STAGES, pf);
1805+
cp_async_fence();
1806+
}
1807+
}
17561808

17571809
for (int kt = kt_start; kt < kt_end; kt++) {
1758-
int cur = (kt - kt_start) % 2;
1759-
if (kt + 1 < kt_end) {
1760-
fetch_tile((kt - kt_start + 1) % 2, kt + 1);
1810+
int cur = (kt - kt_start) % NUM_STAGES;
1811+
int fetch_kt = kt + NUM_STAGES - 1;
1812+
if (fetch_kt < kt_end) {
1813+
fetch_tile((fetch_kt - kt_start) % NUM_STAGES, fetch_kt);
17611814
cp_async_fence();
1762-
// L2 prefetch for tile kt+2
1763-
if (kt + 2 < kt_end) {
1764-
const int pf_tile = (kt + 2) * n_tiles + n_tile;
1815+
// L2 prefetch for tile beyond the pipeline
1816+
if (fetch_kt + 1 < kt_end) {
1817+
const int pf_tile = (fetch_kt + 1) * n_tiles + n_tile;
17651818
prefetch_l2(B_packed + pf_tile * B_STAGE_WORDS);
17661819
prefetch_l2(B_absmax + pf_tile * ABS_STAGE_ELEMS);
17671820
}
1768-
cp_async_wait<1>();
1821+
cp_async_wait<NUM_STAGES - 1>();
17691822
} else {
17701823
cp_async_wait<0>();
17711824
}
@@ -1883,7 +1936,15 @@ static void kbitGroupedGemmProdLaunch(
18831936
int grid_size = (k_splits == 1) ? min(num_sms, total_work) : min(target_blocks, total_work);
18841937

18851938
dim3 block(BLOCK_DIM);
1886-
int smem_size = 2 * STAGE_BYTES;
1939+
int num_stages = pipelineNumStages();
1940+
int smem_size = num_stages * STAGE_BYTES;
1941+
1942+
if (smem_size > 48 * 1024) {
1943+
cudaFuncSetAttribute(
1944+
kbit_grouped_gemm_prod<K, MB, TN, scalar_t, ABSMAX_T>,
1945+
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size
1946+
);
1947+
}
18871948

18881949
kbit_grouped_gemm_prod<K, MB, TN, scalar_t, ABSMAX_T><<<grid_size, block, smem_size, stream>>>(
18891950
A_concat, B_packed_all, B_absmax_all, codebook, C_concat, C_workspace, tile_counters, expert_offsets, K_dim, N,

0 commit comments

Comments
 (0)