@@ -1125,6 +1125,13 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
11251125 constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15 ) & ~15 ;
11261126 constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;
11271127
1128+ // Pipeline depth: 4 stages on datacenter GPUs (228KB shmem), 2 on consumer (100KB)
1129+ #if BNB_DATACENTER_GPU
1130+ constexpr int NUM_STAGES = 4 ;
1131+ #else
1132+ constexpr int NUM_STAGES = 2 ;
1133+ #endif
1134+
11281135 const int n_tiles = N / TILE_N;
11291136 const int k_tiles = (K_dim + TILE_K - 1 ) / TILE_K;
11301137 const int tiles_per_split = (k_tiles + k_splits - 1 ) / k_splits;
@@ -1138,7 +1145,7 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
11381145 const int tid = lane_id % 4 ;
11391146 const int warp_n_base = warp_id * COLS_PER_WARP;
11401147
1141- // Double-buffered shared memory
1148+ // Multi-stage shared memory (NUM_STAGES stages)
11421149 extern __shared__ char smem[];
11431150 auto sh_a = [&](int stage) -> scalar_t * { return reinterpret_cast <scalar_t *>(smem + stage * STAGE_BYTES); };
11441151 auto sh_b = [&](int stage) -> unsigned int * {
@@ -1307,22 +1314,31 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
13071314 }
13081315 };
13091316
1310- // Pipeline: double-buffered cp.async
1311- fetch_tile (0 , kt_start);
1312- cp_async_fence ();
1317+ // Pipeline: NUM_STAGES-deep cp.async (2 on consumer, 4 on datacenter)
1318+ // Pre-fill first (NUM_STAGES - 1) tiles
1319+ {
1320+ int prefill_end = kt_start + NUM_STAGES - 1 ;
1321+ if (prefill_end > kt_end)
1322+ prefill_end = kt_end;
1323+ for (int pf = kt_start; pf < prefill_end; pf++) {
1324+ fetch_tile ((pf - kt_start) % NUM_STAGES, pf);
1325+ cp_async_fence ();
1326+ }
1327+ }
13131328
13141329 for (int kt = kt_start; kt < kt_end; kt++) {
1315- int cur = (kt - kt_start) % 2 ;
1316- if (kt + 1 < kt_end) {
1317- fetch_tile ((kt + 1 - kt_start) % 2 , kt + 1 );
1330+ int cur = (kt - kt_start) % NUM_STAGES;
1331+ int fetch_kt = kt + NUM_STAGES - 1 ;
1332+ if (fetch_kt < kt_end) {
1333+ fetch_tile ((fetch_kt - kt_start) % NUM_STAGES, fetch_kt);
13181334 cp_async_fence ();
1319- // L2 prefetch for tile kt+2 (warms L2 before next fetch_tile issues cp.async)
1320- if (kt + 2 < kt_end) {
1321- const int pf_tile = (kt + 2 ) * n_tiles + n_tile;
1335+ // L2 prefetch for tile beyond the pipeline
1336+ if (fetch_kt + 1 < kt_end) {
1337+ const int pf_tile = (fetch_kt + 1 ) * n_tiles + n_tile;
13221338 prefetch_l2 (B_packed + pf_tile * B_STAGE_WORDS);
13231339 prefetch_l2 (B_absmax + pf_tile * ABS_STAGE_ELEMS);
13241340 }
1325- cp_async_wait<1 >();
1341+ cp_async_wait<NUM_STAGES - 1 >();
13261342 } else {
13271343 cp_async_wait<0 >();
13281344 }
@@ -1392,6 +1408,19 @@ __global__ void __launch_bounds__(TILE_N_VAL <= 64 ? 128 : 256, TILE_N_VAL <= 64
13921408 } // end persistent work loop
13931409}
13941410
// Pipeline stage count for the cp.async software pipeline: 4 on datacenter
// GPUs (sm_90 / sm_100, which have the large per-SM shared memory), 2 on
// consumer parts.
//
// Queries the *current* device — not device 0 — so multi-GPU hosts with mixed
// hardware size shared memory for the device the kernel will actually launch
// on, and caches the answer per device ordinal. Any CUDA API failure falls
// back to the conservative 2-stage answer without poisoning the cache.
//
// NOTE(review): the kernels select their own NUM_STAGES at compile time via
// BNB_DATACENTER_GPU, while this host-side function decides at runtime from
// the compute capability. If the two can ever disagree with the kernel
// expecting 4 stages but the host allocating 2 * STAGE_BYTES, the kernel
// would index past its dynamic shared memory — confirm the macro and this
// check are always consistent for a given build/device combination.
static int pipelineNumStages() {
    // Per-device cache; 0 means "not yet queried" (valid answers are 2 or 4).
    // Concurrent first calls may race, but they write the same value, so the
    // race is benign.
    constexpr int kMaxDevices = 64;
    static int cached[kMaxDevices] = {};

    int dev = 0;
    if (cudaGetDevice(&dev) != cudaSuccess || dev < 0 || dev >= kMaxDevices)
        return 2; // conservative default; also covers out-of-range ordinals

    if (cached[dev] == 0) {
        int major = 0, minor = 0;
        if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev) != cudaSuccess ||
            cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev) != cudaSuccess) {
            return 2; // query failed: do not cache, retry on the next call
        }
        const int sm = major * 10 + minor;
        cached[dev] = (sm == 90 || sm == 100) ? 4 : 2;
    }
    return cached[dev];
}
1423+
13951424// Production GEMM launcher — persistent kernel with auto k_splits
13961425template <int K, int MB, int TN = 128 , typename scalar_t = half, typename ABSMAX_T = unsigned char >
13971426static void kbitGemmProdLaunch (
@@ -1437,7 +1466,16 @@ static void kbitGemmProdLaunch(
14371466 int grid_size = (k_splits == 1 ) ? total_work : min (target_blocks, total_work);
14381467
14391468 dim3 block (BLOCK_DIM);
1440- int smem_size = 2 * STAGE_BYTES;
1469+ int num_stages = pipelineNumStages ();
1470+ int smem_size = num_stages * STAGE_BYTES;
1471+
1472+ // If shared memory exceeds default 48KB limit, increase it
1473+ if (smem_size > 48 * 1024 ) {
1474+ cudaFuncSetAttribute (
1475+ kbit_gemm_prod<K, MB, TN, scalar_t , ABSMAX_T>,
1476+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size
1477+ );
1478+ }
14411479
14421480 kbit_gemm_prod<K, MB, TN, scalar_t , ABSMAX_T><<<grid_size, block, smem_size, stream>>> (
14431481 A, B_packed, B_absmax, codebook, C, C_workspace, tile_counters, M, K_dim, N, k_splits, total_work
@@ -1538,6 +1576,13 @@ __global__ void kbit_grouped_gemm_prod(
15381576 constexpr int ABS_STAGE_ALIGNED = (ABS_STAGE_BYTES + 15 ) & ~15 ;
15391577 constexpr int STAGE_BYTES = A_STAGE_BYTES + B_STAGE_BYTES_VAL + ABS_STAGE_ALIGNED;
15401578
1579+ // Pipeline depth: 4 stages on datacenter GPUs, 2 on consumer
1580+ #if BNB_DATACENTER_GPU
1581+ constexpr int NUM_STAGES = 4 ;
1582+ #else
1583+ constexpr int NUM_STAGES = 2 ;
1584+ #endif
1585+
15411586 const int n_tiles = N / TILE_N;
15421587 const int k_tiles = (K_dim + TILE_K - 1 ) / TILE_K;
15431588 const int tiles_per_split = (k_tiles + k_splits - 1 ) / k_splits;
@@ -1552,7 +1597,7 @@ __global__ void kbit_grouped_gemm_prod(
15521597 const int tid = lane_id % 4 ;
15531598 const int warp_n_base = warp_id * (TILE_N / NUM_WARPS);
15541599
1555- // Double-buffered shared memory
1600+ // Multi-stage shared memory (NUM_STAGES stages)
15561601 extern __shared__ char smem[];
15571602 auto sh_a = [&](int stage) -> scalar_t * { return reinterpret_cast <scalar_t *>(smem + stage * STAGE_BYTES); };
15581603 auto sh_b = [&](int stage) -> unsigned int * {
@@ -1750,22 +1795,30 @@ __global__ void kbit_grouped_gemm_prod(
17501795 }
17511796 };
17521797
1753- // Pipeline: double-buffered cp.async over this split's k-tile range
1754- fetch_tile (0 , kt_start);
1755- cp_async_fence ();
1798+ // Pipeline: NUM_STAGES-deep cp.async (2 on consumer, 4 on datacenter)
1799+ {
1800+ int prefill_end = kt_start + NUM_STAGES - 1 ;
1801+ if (prefill_end > kt_end)
1802+ prefill_end = kt_end;
1803+ for (int pf = kt_start; pf < prefill_end; pf++) {
1804+ fetch_tile ((pf - kt_start) % NUM_STAGES, pf);
1805+ cp_async_fence ();
1806+ }
1807+ }
17561808
17571809 for (int kt = kt_start; kt < kt_end; kt++) {
1758- int cur = (kt - kt_start) % 2 ;
1759- if (kt + 1 < kt_end) {
1760- fetch_tile ((kt - kt_start + 1 ) % 2 , kt + 1 );
1810+ int cur = (kt - kt_start) % NUM_STAGES;
1811+ int fetch_kt = kt + NUM_STAGES - 1 ;
1812+ if (fetch_kt < kt_end) {
1813+ fetch_tile ((fetch_kt - kt_start) % NUM_STAGES, fetch_kt);
17611814 cp_async_fence ();
1762- // L2 prefetch for tile kt+2
1763- if (kt + 2 < kt_end) {
1764- const int pf_tile = (kt + 2 ) * n_tiles + n_tile;
1815+ // L2 prefetch for tile beyond the pipeline
1816+ if (fetch_kt + 1 < kt_end) {
1817+ const int pf_tile = (fetch_kt + 1 ) * n_tiles + n_tile;
17651818 prefetch_l2 (B_packed + pf_tile * B_STAGE_WORDS);
17661819 prefetch_l2 (B_absmax + pf_tile * ABS_STAGE_ELEMS);
17671820 }
1768- cp_async_wait<1 >();
1821+ cp_async_wait<NUM_STAGES - 1 >();
17691822 } else {
17701823 cp_async_wait<0 >();
17711824 }
@@ -1883,7 +1936,15 @@ static void kbitGroupedGemmProdLaunch(
18831936 int grid_size = (k_splits == 1 ) ? min (num_sms, total_work) : min (target_blocks, total_work);
18841937
18851938 dim3 block (BLOCK_DIM);
1886- int smem_size = 2 * STAGE_BYTES;
1939+ int num_stages = pipelineNumStages ();
1940+ int smem_size = num_stages * STAGE_BYTES;
1941+
1942+ if (smem_size > 48 * 1024 ) {
1943+ cudaFuncSetAttribute (
1944+ kbit_grouped_gemm_prod<K, MB, TN, scalar_t , ABSMAX_T>,
1945+ cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size
1946+ );
1947+ }
18871948
18881949 kbit_grouped_gemm_prod<K, MB, TN, scalar_t , ABSMAX_T><<<grid_size, block, smem_size, stream>>> (
18891950 A_concat, B_packed_all, B_absmax_all, codebook, C_concat, C_workspace, tile_counters, expert_offsets, K_dim, N,
0 commit comments