# kbit GEMM Kernel: Optimization Guide

This document catalogs the remaining performance optimizations for the
production kbit GEMM kernel (`kbit_gemm_prod`). Each optimization is
described with its expected impact, implementation approach, and testing
strategy.

The kernel is functionally complete (fp16 + bf16, split-K, ldmatrix with
swizzle, cp.async double-buffered pipeline, 139 tests passing). The
remaining work is purely about throughput.

---

## Current State (Baseline)

**Kernel configuration:**
- TILE_M = 16 (one MMA M-block per warp)
- TILE_N = 128 (N_BLOCKS = 2, each warp covers 16 columns)
- TILE_K = 64 (4 MMA k-sub-tiles of 16)
- 256 threads = 8 warps, each warp handles the same M rows and a slice of N
- Double-buffered cp.async pipeline
- ldmatrix.x4 with XOR bank-conflict swizzle for A tile

**RTX 4090 benchmark (K=4, fp16, k_chunks=1):**

| M | K_dim | N | kbit (us) | cuBLAS (us) | Speedup |
|---:|------:|------:|----------:|------------:|--------:|
| 1 | 4096 | 4096 | 109 | 43 | 0.39x |
| 1 | 4096 | 11008 | 82 | 128 | **1.56x** |
| 4 | 4096 | 4096 | 92 | 22 | 0.24x |
| 4 | 4096 | 11008 | 100 | 121 | **1.21x** |
| 16 | 4096 | 4096 | 149 | 28 | 0.19x |

**Why it's slow for square matrices:** Each thread block computes a
16x128 output tile. With M=16, only 1 M-tile exists, meaning only
(N/128) blocks launch. For N=4096, that's 32 blocks on a 128-SM GPU --
25% utilization. And each block does very little compute per shared
memory load because TILE_M=16 means only one MMA row-block per warp.

**Why it wins for M=1 large-N:** The GEMM is memory-bandwidth-bound.
The kernel reads 4-bit compressed weights (4x less data than fp16
cuBLAS), which directly translates to speedup.
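
As a back-of-the-envelope check, the traffic difference can be sketched with a
small helper (`bytes_moved` is hypothetical and illustrative only; it ignores
the output write and the absmax metadata, which are small at these shapes):

```python
def bytes_moved(M, K_dim, N, bits=4, elt=2):
    """Approximate bytes read for one GEMM: activations + weights.

    The kbit kernel reads K_dim*N*bits/8 bytes of packed weights; an fp16
    GEMM reads K_dim*N*elt bytes. For small M the weight matrix dominates.
    """
    a = M * K_dim * elt              # fp16 activations
    w_kbit = K_dim * N * bits // 8   # 4-bit packed weights
    w_fp16 = K_dim * N * elt         # fp16 weights
    return a + w_kbit, a + w_fp16

kbit, fp16 = bytes_moved(1, 4096, 11008)
print(f"kbit reads {kbit/1e6:.1f} MB, fp16 reads {fp16/1e6:.1f} MB "
      f"-> {fp16/kbit:.1f}x less traffic")
```

At M=1 the ratio sits just under 4x, which is consistent with the 1.56x
measured speedup once cuBLAS's own overheads and the kernel's compute cost
are factored in.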

---

## Optimization 1: Multi-M-Block Tiling

**Priority: HIGHEST. This is the single biggest performance lever.**

### The Problem

Currently TILE_M=16. Each warp executes 2 MMA operations per k-sub-tile
(N_BLOCKS=2). The A fragment is loaded once and used for only 2 MMAs.
The compute-to-load ratio is low.

### The Fix

Template the kernel on `M_BLOCKS` (1, 2, 3, 4). TILE_M becomes
`M_BLOCKS * 16`. Each warp handles multiple M-blocks, reusing the same
B fragment across all of them:

```
Current (M_BLOCKS=1):
  Each warp: 1 M-block x 2 N-blocks = 2 MMAs per k-sub-tile

Target (M_BLOCKS=4):
  Each warp: 4 M-blocks x 2 N-blocks = 8 MMAs per k-sub-tile
```

The B fragment (dequantized from bit-planes) is the expensive part --
codebook lookup via shuffle, absmax multiply. With M_BLOCKS=4, this cost
is amortized over 4x more MMA operations.

### Implementation

1. Add `M_BLOCKS` template parameter to `kbit_gemm_prod`
2. FragC accumulator becomes `float frag_c[M_BLOCKS][N_BLOCKS][4]`
3. A fragment loading: load `M_BLOCKS` fragments per k-sub-tile (ldmatrix
   for each M-block's 16 rows)
4. Inner loop: for each B fragment, iterate over M_BLOCKS and issue MMA
5. A tile in shared memory grows: `M_BLOCKS * 16 * TILE_K * sizeof(scalar_t)`
6. Output write: iterate over M_BLOCKS for the C tile write
7. Host-side dispatch selects M_BLOCKS based on M:
   - M <= 16: M_BLOCKS=1
   - M <= 32: M_BLOCKS=2
   - M <= 48: M_BLOCKS=3
   - M >= 49: M_BLOCKS=4
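
The dispatch table in step 7 amounts to rounding M up to the next 16-row
block, capped at 4. A sketch (`select_m_blocks` is an illustrative name, not
the actual host API):

```python
def select_m_blocks(M):
    """Pick the M_BLOCKS instantiation for batch size M.

    Each M-block covers 16 rows, so round M up to a multiple of 16 and
    cap at 4 blocks; beyond TILE_M=64 the grid simply gains more M-tiles.
    """
    if M <= 16:
        return 1
    if M <= 32:
        return 2
    if M <= 48:
        return 3
    return 4
```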

### Shared Memory Impact

| M_BLOCKS | TILE_M | A tile (bytes) | B tile K=4 | Absmax | Per stage | 2 stages |
|---------:|-------:|---------------:|-----------:|-------:|----------:|---------:|
| 1 | 16 | 2,048 | 4,096 | 256 | 6,400 | 12,800 |
| 2 | 32 | 4,096 | 4,096 | 256 | 8,448 | 16,896 |
| 4 | 64 | 8,192 | 4,096 | 256 | 12,544 | 25,088 |

All fit within RTX 4090's 100 KB limit. Even M_BLOCKS=4 with 4 pipeline
stages would use ~50 KB.
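
The per-stage figures above can be reproduced with a small calculator (a
sketch; it assumes one fp16 absmax scale per output column, which matches the
256-byte figure in the table):

```python
def smem_per_stage(m_blocks, tile_n=128, tile_k=64, bits=4, elt=2):
    """Shared memory bytes per pipeline stage.

    A tile: (m_blocks*16) x tile_k scalars; B tile: tile_k x tile_n
    weights packed at `bits` bits each; absmax: one scale per column.
    """
    a_tile = m_blocks * 16 * tile_k * elt
    b_tile = tile_k * tile_n * bits // 8
    absmax = tile_n * elt
    return a_tile + b_tile + absmax
```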

### Register Impact

FragC grows from 2*4 = 8 floats to M_BLOCKS*2*4 = 32 floats for M_BLOCKS=4.
FragA grows from 4 uint32 to M_BLOCKS*4 = 16 uint32. Total registers ~50-60,
well within the 255 limit.

### Expected Speedup

For M=4, K_dim=4096, N=4096 with M_BLOCKS=4: each block does 4x more compute
per B tile load. Since the kernel is currently B-load-limited for these sizes,
expect roughly **2-3x improvement** (not full 4x due to diminishing returns
from A tile growth).

### Test Strategy

- M_BLOCKS=1 must produce identical output to the current kernel (bit-exact)
- M_BLOCKS=2,3,4 must match Python reference within existing tolerance
- Test partial M-tiles: M=5 with M_BLOCKS=4 (TILE_M=64, only 5 rows valid)

---

## Optimization 2: Larger N_BLOCKS per Warp

**Priority: HIGH. Complements multi-M-block.**

### The Problem

Currently N_BLOCKS=2, so each warp covers 16 of the 128 tile columns.
With 8 warps, that's 8*16 = 128 columns (full tile). But each warp
only issues 2 MMA ops per k-sub-tile per M-block.

### The Fix

Increase N_BLOCKS to 4 (each warp covers 32 columns). Then 4 warps
cover the full TILE_N=128. The remaining 4 warps cover additional M
rows (for the 2-warps-along-M x 4-warps-along-N layout from the
design doc).

### Warp Layout

The design doc specifies for TILE_M=64, TILE_N=128:

```
2 warps along M (each handles 32 rows = 2 M-blocks)
x 4 warps along N (each handles 32 cols = 4 N-blocks)
= 8 warps total

Each warp: 2 M-blocks x 4 N-blocks = 8 MMAs per k-sub-tile
With TILE_K=64 (4 k-sub-tiles): 32 MMAs per warp per K-tile
```

This is the target configuration. Combined with multi-M-block, it gives
each warp 4x more compute than the current kernel.

### Implementation

1. Change N_BLOCKS to 4
2. Change warp-to-tile mapping: `warp_m = warp_id / 4`, `warp_n = warp_id % 4`
3. Each warp handles M-blocks `[warp_m * M_BLOCKS_PER_WARP ... (warp_m+1) * M_BLOCKS_PER_WARP - 1]`
   and N-blocks `[warp_n * 4 ... warp_n * 4 + 3]`
4. Fragment accumulators: `frag_c[M_BLOCKS_PER_WARP][4][4]`
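
The mapping in steps 2-3 can be modeled directly (illustrative helper, not the
kernel's actual code; M-blocks are 16 rows, N-blocks 8 columns):

```python
def warp_tile(warp_id, m_blocks_per_warp=2, n_blocks_per_warp=4):
    """Map a warp to its (M-block, N-block) ranges in the 2x4 warp grid.

    warp_m selects the warp row, warp_n the warp column; each warp owns
    m_blocks_per_warp M-blocks and n_blocks_per_warp N-blocks.
    """
    warp_m, warp_n = divmod(warp_id, 4)
    m_blocks = list(range(warp_m * m_blocks_per_warp,
                          (warp_m + 1) * m_blocks_per_warp))
    n_blocks = list(range(warp_n * n_blocks_per_warp,
                          (warp_n + 1) * n_blocks_per_warp))
    return m_blocks, n_blocks
```

A quick sanity property: the 8 warps together cover each of the 4x16
(M-block, N-block) pairs of the 64x128 tile exactly once.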

### Expected Speedup

Combined with multi-M-block: each warp issues **4x** more MMAs per
k-sub-tile than the current kernel (8 vs. 2), and each thread block does
4x more compute per B tile load (TILE_M grows from 16 to 64). For M>=4
square matrices, expect the kernel to **match or beat cuBLAS**.

---

## Optimization 3: C Output Staging Through Shared Memory

**Priority: MEDIUM. Improves memory write efficiency.**

### The Problem

Currently, each thread writes its FragC values directly to global memory.
The MMA fragment layout means threads in a warp write to scattered row
positions:
- Thread with gid=0 writes rows 0, 8
- Thread with gid=1 writes rows 1, 9
- etc.

These writes hit different cache lines (each row is N*2 bytes apart),
causing uncoalesced writes.

### The Fix

After the K-tile loop, stage the output through shared memory:

1. Each warp writes its FragC values to shared memory in the natural
   fragment order (scattered rows, but shmem is fast)
2. `__syncthreads()`
3. All threads cooperatively read from shared memory in row-major order
   and write to global memory with coalesced access (consecutive threads
   write consecutive addresses within the same row)
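
The scattered-row pattern follows from the standard mma.m16n8 C-fragment
layout, where gid = lane/4. A small host-side model of that layout
(illustrative Python, not device code):

```python
def frag_c_coords(lane):
    """(row, col) of the 4 accumulators a lane holds in one m16n8 C fragment.

    Standard mma.m16n8k16 layout: element i lives at
    row = lane//4 + 8*(i//2), col = 2*(lane%4) + (i%2).
    """
    return [(lane // 4 + 8 * (i // 2), 2 * (lane % 4) + i % 2)
            for i in range(4)]
```

Lane 0 (gid=0) holds rows 0 and 8, so a direct store touches two cache lines
N*2 bytes apart; after staging, the final global write can walk rows in order
with consecutive threads on consecutive columns.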

### Shared Memory Reuse

The pipeline's shared memory is no longer needed during the output phase
(the K-tile loop is done), so the C staging area can reuse the pipeline
buffers. For TILE_M=64, TILE_N=128, the C tile is 64*128*2 = 16 KB in
fp16, which fits easily in the two pipeline stages' combined allocation
(~25 KB for M_BLOCKS=4).

### Expected Speedup

Moderate. The output write is not on the critical path for large K_dim
(the K-tile loop dominates). For small K_dim or when the kernel is
already close to bandwidth-optimal, this can give **5-15% improvement**.

---

## Optimization 4: Persistent Kernel

**Priority: MEDIUM. Helps SM utilization for small tile counts.**

### The Problem

The current 2D/3D grid launch creates one block per output tile (or per
split-K chunk). When the number of tiles is less than the GPU's SM count,
SMs sit idle.

### The Fix

Launch exactly `num_SMs` blocks. Each block loops over assigned work items
(linearized (m_tile, n_tile, k_chunk) triples). Benefits:

1. **Better utilization:** All SMs are always active
2. **Accumulator persistence:** When consecutive work items share the same
   output tile, the accumulators stay in registers (no atomicAdd needed)
3. **First-contributor optimization:** The first block to write a tile does
   a plain store to the fp32 workspace (no need to zero it first). Only
   subsequent contributors use atomicAdd.

### Implementation

See design doc Section 6 for the full design. The key structure:
```cpp
// Contiguous range of linearized work items per block; k_chunks is the
// fastest-varying index, so one block usually sees all chunks of a tile.
int total_work = m_tiles * n_tiles * k_chunks;
int work_per_block = div_ceil(total_work, gridDim.x);
int my_start = blockIdx.x * work_per_block;
int my_end = min(my_start + work_per_block, total_work);

int prev_mn = -1;  // output tile currently held in registers
for (int work_id = my_start; work_id < my_end; work_id++) {
  int mn_id = work_id / k_chunks;
  int k_chunk_id = work_id % k_chunks;
  if (mn_id != prev_mn) {
    // Moving to a new output tile: flush the finished one first.
    if (prev_mn >= 0) write_output(...);
    zero_accumulators();
    prev_mn = mn_id;
  }
  process_k_range(k_chunk_id, ...);
}
if (prev_mn >= 0) write_output(...);
```

### Expected Speedup

Depends on the shape. For shapes where `m_tiles * n_tiles < num_SMs`
(e.g., M=16, N=4096 on a 128-SM GPU: 1*32 = 32 tiles), the persistent
kernel can improve throughput by **2-3x** by enabling split-K without the
atomicAdd overhead. For shapes with many tiles, the benefit is marginal.

---

## Optimization 5: cp.async for A Tile

**Priority: LOW. Minor improvement.**

### The Problem

Currently A is loaded synchronously (element-by-element) while B and
absmax use cp.async. A could also use cp.async for better latency hiding.

### The Complication

A needs bounds checking (`gr < M && gc < K_dim`) and XOR swizzle on the
destination address. cp.async copies from a source address to a destination
address, so the swizzle can be applied to the destination. But bounds
checking is harder -- cp.async doesn't support conditional copies.

### Possible Approach

Use `cp.async.cg.shared.global` for the interior of the A tile (rows that
are guaranteed in-bounds), and synchronous loads only for boundary rows.
For TILE_M=64 and M=4096, almost all rows are in-bounds. Only the last
M-tile may have boundary rows.
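
The interior/boundary split can be sketched per M-tile (hypothetical helper;
a real implementation would also honor cp.async's 16-byte copy granularity
and the K_dim bounds within each row):

```python
def split_a_rows(m_tile_idx, M, tile_m=64):
    """Split an A tile's rows into cp.async-safe and guarded counts.

    Rows whose global index is < M can be copied unconditionally with
    cp.async; the remaining rows need a predicated synchronous path
    (or zero-fill).
    """
    base = m_tile_idx * tile_m
    in_bounds = max(0, min(tile_m, M - base))
    return in_bounds, tile_m - in_bounds
```

For M=4096 every one of the 64 M-tiles is fully in-bounds; only a ragged last
tile (e.g. M=4100) has guarded rows.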

### Expected Speedup

Small (2-5%). A tile is only 2-8 KB per stage, much smaller than B tile.
The synchronous load latency is already partially hidden by the pipeline.

---

## Recommended Implementation Order

1. **Multi-M-block tiling** (Optimization 1) -- biggest impact, enables the
   target warp layout
2. **Larger N_BLOCKS** (Optimization 2) -- natural companion to multi-M-block,
   together they achieve the design doc's target of 32 MMAs per warp per K-tile
3. **C output staging** (Optimization 3) -- polish for write efficiency
4. **Persistent kernel** (Optimization 4) -- improves edge cases
5. **cp.async for A** (Optimization 5) -- diminishing returns

After optimizations 1+2, re-benchmark. If the kernel matches cuBLAS for
M=1-32 with large N, the remaining optimizations can be deprioritized in
favor of integration work (wiring into Linear4bit, auto-tuning k_chunks).

---

## Integration Work (Not Performance, But Required)

These are not performance optimizations but are needed to ship:

- **Wire into LinearNbit module:** Replace the dequant+cuBLAS path with a
  call to `kbit_gemm_prod` when conditions are met (CUDA, fp16/bf16,
  N % 128 == 0, K_dim % 64 == 0)
- **Auto-select k_chunks:** Based on M, N, K_dim, and SM count. Formula
  from design doc Section 6.2.
- **Remove staging kernels:** Clean up Stages 3-5 kernels, keeping only
  the production kernel and the debug MMA test
- **Lint + PR:** Run ruff/clang-format, merge to main