@@ -222,7 +222,7 @@ gives the compiler 8 independent FMA chains for ILP.
 
 **Reduction:**
 - Intra-warp: shuffle reduction (5 steps)
-- Inter-warp: 2-phase shared memory (32 bytes), single `__syncthreads`
+- Inter-warp: generalized loop over NUM_WARPS partial sums in shared memory
 - Thread 0 writes M output values to C
 
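The two reduction phases listed above can be modeled host-side in Python (a sketch only: the real kernel is CUDA and uses `__shfl_down_sync` plus shared memory; `NUM_WARPS` comes from the text, while the function names here are invented):

```python
# Host-side model of the two-phase reduction. Illustrative only: the
# real code is CUDA; num_warps mirrors NUM_WARPS from the text, and the
# helper names are made up.

def warp_shuffle_reduce(lane_vals):
    """Phase 1: 5-step shuffle-down reduction over 32 lanes
    (offsets 16, 8, 4, 2, 1); lane 0 ends up with the warp-wide sum."""
    v = list(lane_vals)
    assert len(v) == 32
    offset = 16
    while offset >= 1:                 # 5 steps total
        for i in range(32 - offset):   # lane i reads lane i+offset
            v[i] = v[i] + v[i + offset]
        offset //= 2
    return v[0]

def block_reduce(per_lane_sums, num_warps):
    """Phase 2: each warp's lane-0 partial lands in 'shared memory';
    one thread then loops over the num_warps partials."""
    shared = [warp_shuffle_reduce(per_lane_sums[w * 32:(w + 1) * 32])
              for w in range(num_warps)]
    total = 0.0
    for p in shared:  # generalized loop over NUM_WARPS partial sums
        total += p
    return total
```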
 **Design decisions:**
@@ -235,6 +235,14 @@ gives the compiler 8 independent FMA chains for ILP.
 | B absmax | float32 | Uses quantize_kbit output directly, no repack |
 | Inner loop | Vectorized 4x8 | int4 A loads + sub-loop gives ILP without blowing registers |
 
+**Tiled v2 kernel (experimental, `kbit_scalar_gemv_tiled_v2`):**
+- 128 threads (4 warps), one N-tile (128 columns) per block
+- Cooperative cp.async loading of the full B tile + absmax into shared memory
+- Split-K support with an atomicAdd workspace and tile_counters
+- **Not adopted for production**: cooperative tile loading + `__syncthreads` overhead
+  dominates at M=1-2 (each thread uses only 1/128 of the loaded tile data).
+  The per-column kernel is 10-45% faster for M=1. Kept in the code for reference.
+
 ---
 
 ## 2. MMA dequant kernel (`kbit_gemm_prod`)
@@ -247,8 +255,9 @@ gives the compiler 8 independent FMA chains for ILP.
 - TILE_N=64 for M<=16 (128 threads, 4 warps, `__launch_bounds__(128, 12)`)
 - TILE_N=128 for M>16 (256 threads, 8 warps)
 - TILE_K=64, TILE_M=16*M_BLOCKS (M_BLOCKS=1..4)
-- Double-buffered cp.async pipeline for A, B, and absmax tiles
+- cp.async pipeline for A, B, and absmax tiles (NUM_STAGES: 4 on datacenter, 2 on consumer)
 - Persistent kernel with split-K when tiles < target SM occupancy
+- L2 prefetch hints for tile kt+2 on datacenter GPUs (`prefetch.global.L2`)
 
 **Data format:**
 - B_packed: tiled from `repack_kbit` — `[k_tiles * n_tiles * TILE_N * B_COL_WORDS]`
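The flat B_packed shape above implies an index computation. A minimal sketch, assuming the four dimensions nest in the order the shape is written (`repack_kbit` is the source of truth; this helper name is hypothetical):

```python
def b_packed_word_index(kt, nt, col, word, n_tiles, tile_n, b_col_words):
    """Flat word index into B_packed, assuming row-major nesting
    [k_tiles][n_tiles][TILE_N][B_COL_WORDS]. The nesting order is an
    assumption based on the written shape, not confirmed by the text."""
    return ((kt * n_tiles + nt) * tile_n + col) * b_col_words + word
```

With hypothetical sizes (n_tiles=4, TILE_N=64, B_COL_WORDS=8), advancing one k-tile skips a whole `n_tiles * TILE_N * B_COL_WORDS` block.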
@@ -266,16 +275,24 @@ for each (k_sub, n_block) pair:
   mma.sync.aligned.m16n8k16
 ```
 
-**k_splits heuristic (TILE_N=64):**
+**k_splits heuristic:**
 ```
-target_blocks = 128 SMs * 4 blocks/SM = 512
-if mn_tiles < 512:
-    k_splits = min(k_tiles, ceil(512 / mn_tiles))
-grid = min(512, mn_tiles * k_splits)
+# Consumer (RTX 4090, 128 SMs):
+TARGET_BLOCKS_PER_SM = 4 (TN=64) or 1 (TN=128)
+target_blocks = 128 * TARGET_BLOCKS_PER_SM
+
+# Datacenter (H100, 132 SMs):
+TARGET_BLOCKS_PER_SM = 6 (TN=64) or 2 (TN=128)
+target_blocks = 132 * TARGET_BLOCKS_PER_SM
+
+k_splits = min(k_tiles, ceil(target_blocks / mn_tiles))
+grid = min(target_blocks, mn_tiles * k_splits)
 ```
 
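Plugging numbers into the heuristic above, as an executable sketch (the shape is hypothetical, the function name is invented, and the no-split branch for large grids is an assumption not spelled out in the text):

```python
from math import ceil

def pick_k_splits(mn_tiles, k_tiles, num_sms, blocks_per_sm):
    """Model of the heuristic above: oversplit K until the grid reaches
    the per-GPU target block count. The branch where mn_tiles already
    meets the target (no split-K) is an assumption."""
    target_blocks = num_sms * blocks_per_sm
    if mn_tiles >= target_blocks:
        return 1, mn_tiles
    k_splits = min(k_tiles, ceil(target_blocks / mn_tiles))
    return k_splits, min(target_blocks, mn_tiles * k_splits)

# Hypothetical M=1 shape: N=4096 -> 64 n-tiles at TILE_N=64,
# K=4096 -> 64 k-tiles at TILE_K=64.
consumer = pick_k_splits(64, 64, num_sms=128, blocks_per_sm=4)    # RTX 4090, TN=64
datacenter = pick_k_splits(64, 64, num_sms=132, blocks_per_sm=6)  # H100, TN=64
```

The same 64x64 tile grid splits K 8 ways on the consumer target (512 blocks) but 13 ways on the datacenter target (792 blocks).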
-Split-K uses atomicAdd + tile_counters for the last-arriving split to
-do the final reduction.
+Higher targets on datacenter GPUs improve H100 MMA performance by 5-16%
+(more concurrent blocks per SM for better latency hiding with 3.35 TB/s
+bandwidth). Split-K uses atomicAdd + tile_counters for the last-arriving
+split to do the final reduction.
 
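A toy sequential model of that split-K hand-off (illustrative only: the real code uses atomics across concurrently running blocks; `workspace` and `tile_counter` follow the text, the rest is invented):

```python
def run_tile_splits(partials):
    """Every split atomicAdds its partial into an fp32 workspace and
    bumps the tile counter; whichever split arrives last performs the
    final reduction and writes the output tile."""
    workspace = 0.0
    tile_counter = 0
    k_splits = len(partials)
    output = None
    for p in partials:              # splits may complete in any order
        workspace += p              # atomicAdd into the fp32 workspace
        tile_counter += 1           # atomicAdd on tile_counters
        if tile_counter == k_splits:
            output = workspace      # last arriver writes the final C tile
    return output
```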
 **The fundamental constraint on Ada:**
 `mma.sync` is synchronous — the warp stalls until the MMA completes
@@ -512,3 +529,26 @@ On Hopper/datacenter-Blackwell, the MMA dequant kernel could be
 restructured to issue MMA asynchronously while doing ALU dequant in
 parallel. This would eliminate the 39:1 instruction overhead that
 limits the current kernel on Ada. That is a separate future effort.
+
+**Datacenter GPU optimizations (`BNB_DATACENTER_GPU`):**
+
+The macro `BNB_DATACENTER_GPU` targets sm_90 (H100/H200) and sm_100
+(B200/GB200) explicitly. sm_120 (RTX 5090) is consumer despite its
+arch number being >900 and must NOT match.
+
+Implemented optimizations (all behind `#if BNB_DATACENTER_GPU`):
+1. **L2 prefetch hints** — `asm("prefetch.global.L2 [%0];" :: "l"(ptr))`
+   for tile kt+2 in the MMA/grouped MMA pipelines, and for the next
+   K-block in the scalar GEMV inner loop
+2. **4-stage pipeline** — the MMA and grouped MMA kernels use a 4-stage
+   cp.async pipeline (vs 2-stage on consumer). H100 has 228KB shmem vs
+   100KB. Neutral for the current K_dim=2048-5120 (only 32-80 k-tiles);
+   would help more with larger K.
+3. **Higher k_splits targets** — TARGET_BLOCKS_PER_SM increased (TN=64:
+   4→6, TN=128: 1→2) for better SM occupancy on H100's 132 SMs. Gives a
+   5-16% MMA improvement on Qwen3 70B shapes.
+
+Rejected: scalar GEMV warp count 2→4 (128 threads per block) — harmful
+because K_dim=2048 gives only 64 k-blocks for 128 threads, leaving 50%
+idle. Skipped: TMA bulk copy — the current cp.async does 1 copy per
+thread at these tile sizes, so the TMA benefit is marginal.
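The sm-matching rule above can be written as a tiny predicate (a sketch: the real guard is a preprocessor macro over the CUDA arch macros, and `is_datacenter_arch` is an invented name):

```python
def is_datacenter_arch(cuda_arch):
    """Match sm_90 (H100/H200) and sm_100 (B200/GB200) explicitly
    rather than testing arch >= 900, so that sm_120 (RTX 5090)
    stays on the consumer path."""
    return cuda_arch in (900, 1000)
```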