# MMA Kernel Optimization Spec

## Current State

The scalar GEMV kernel (v8) handles M=1-4 efficiently, achieving 3-5x speedup
over cuBLAS fp16 at M=1. However, at M>=2, cuBLAS switches to tensor core GEMM
and is 1.2-1.6x faster than our scalar kernel. The existing MMA kernel
(`kbit_gemm_prod`) is too slow at small M to fill this gap.

### Scalar GEMV v8 (k=4, shape 0: K=2048 N=5120)

| M | us   | GB/s | vs cuBLAS fp16 |
|---|------|------|----------------|
| 1 | 13.1 | 512  | 3.9x faster    |
| 2 | 14.8 | 450  | 1.2x slower    |
| 3 | 16.6 | 401  | 1.3x slower    |
| 4 | 19.8 | 337  | 1.6x slower    |

cuBLAS fp16: ~12.3 us for M=2-4 (tensor cores, flat scaling).

### Target

An MMA-based dequant kernel that beats cuBLAS fp16 for M=2-16 by leveraging
the 3.2x data compression from k-bit quantization while using tensor cores for
the multiply-accumulate. Target: **8-10 us for M=2-4** (matching the theoretical
DRAM minimum of 8.7 us at 75% bandwidth).

---

## Why the Current MMA Kernel is Slow

Four compounding problems at small M, analyzed for k=4, K=2048, N=5120:

### 1. SM Utilization: 31%

With TILE_N=128, there are only `N/128 = 40` n-tiles. At M<=16, `m_tiles=1`,
so `total_work = 40`. On 128 SMs (RTX 4090), 88 SMs sit completely idle.

The k_splits heuristic doesn't trigger because B data (5.6 MB) is under the
24 MB DRAM threshold. Even with aggressive k_splits:

| k_splits | total_work | grid | SM util |
|----------|------------|------|---------|
| 1        | 40         | 40   | 31%     |
| 2        | 80         | 80   | 62%     |
| 4        | 160        | 128  | 100%    |

But k_splits > 1 adds atomicAdd overhead and a `__threadfence` + tile_counter
synchronization per work item.
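
The utilization numbers above follow from a few lines of arithmetic. This sketch assumes one block per work item and 128 SMs, as described above:

```python
# Reproduces the SM-utilization table above (RTX 4090: 128 SMs).
# Work items = m_tiles * n_tiles * k_splits; each work item is one block.
NUM_SMS = 128

def sm_utilization(n, tile_n, k_splits, m_tiles=1):
    n_tiles = n // tile_n                      # 5120 / 128 = 40
    total_work = m_tiles * n_tiles * k_splits  # blocks launched
    grid = min(total_work, NUM_SMS)            # SMs with at least one block
    return total_work, grid, grid / NUM_SMS

# k_splits table for TILE_N=128, N=5120:
for ks in (1, 2, 4):
    work, grid, util = sm_utilization(5120, 128, ks)
    print(ks, work, grid, f"{util:.0%}")
```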

### 2. MMA Compute Waste: 75-94%

`mma.sync.aligned.m16n8k16` is the smallest MMA tile on sm_89. It computes
16 M-rows regardless of actual M. At M=1, 15/16 rows are zero-padded:

| M  | Useful outputs | Total MMA outputs | Utilization |
|----|----------------|-------------------|-------------|
| 1  | 128            | 2048              | 6.2%        |
| 2  | 256            | 2048              | 12.5%       |
| 4  | 512            | 2048              | 25.0%       |
| 8  | 1024           | 2048              | 50.0%       |
| 16 | 2048           | 2048              | 100.0%     |

This is an inherent hardware limitation: there is no m4n8k16 or m8n8k16 on
Ada Lovelace. M < 16 always wastes MMA compute.
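
The table follows directly from the fixed 16-row MMA shape at TILE_N=128:

```python
# MMA-utilization arithmetic: m16n8k16 always computes 16 rows, so a
# TILE_N=128 tile produces 16 * 128 = 2048 outputs regardless of M.
TILE_M, TILE_N = 16, 128

def mma_utilization(m):
    useful = m * TILE_N        # output rows holding real data
    total = TILE_M * TILE_N    # outputs the hardware computes anyway
    return useful, total, useful / total

for m in (1, 2, 4, 8, 16):
    useful, total, util = mma_utilization(m)
    print(m, useful, total, f"{util:.1%}")
```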

### 3. A Tile DRAM Waste

Loading TILE_M * TILE_K * 2 = 2048 bytes per A stage, but at M=1 only
128 bytes are useful (6%). At M=4: 512 bytes useful (25%). This wastes
DRAM bandwidth and cp.async slots.
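
The same waste ratio as the MMA table, expressed in bytes per stage:

```python
# A-tile waste arithmetic: a full stage loads TILE_M * TILE_K fp16
# values, but only M rows carry real data.
TILE_M, TILE_K = 16, 64

def a_stage_bytes(m):
    loaded = TILE_M * TILE_K * 2   # bytes fetched per stage (fp16)
    useful = m * TILE_K * 2        # bytes actually used
    return loaded, useful

loaded, useful = a_stage_bytes(1)  # 2048 loaded, 128 useful (6.25%)
```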

### 4. Dequant is the Bottleneck, Not MMA

Per B element, dequant requires:
- k bit extractions (shift + AND + shift + OR each): ~3k instructions
- 1 `__shfl_sync` (codebook lookup): 1 instruction
- 1 scale multiply: 1 instruction
- Total: ~3k + 2 instructions per element (14 for k=4)

Per TILE_N x TILE_K tile: 128 * 64 = 8192 elements to dequant.
Each thread dequants 4 elements per iteration (idx0-idx3), so
8192 / 4 / 32 lanes = 64 iterations per warp.

The MMA instruction (m16n8k16) takes ~8 cycles on tensor cores.
The dequant to prepare one B fragment takes ~64 scalar instructions.
**MMA is not the bottleneck; dequant is.**
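
The instruction counts above, as checkable arithmetic (the ~3 instructions per extraction is the estimate from the list, not a measured number):

```python
# Dequant-cost arithmetic for k-bit elements.
def dequant_instructions(k):
    # ~3 instructions per bit extraction, plus one __shfl_sync
    # codebook lookup and one scale multiply
    return 3 * k + 2

TILE_N, TILE_K, LANES, ELEMS_PER_ITER = 128, 64, 32, 4
elements = TILE_N * TILE_K                            # 8192 per tile
iters_per_warp = elements // ELEMS_PER_ITER // LANES  # 64 per warp
```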

---

## Optimization Strategy

### Dispatch Policy

Use the right kernel for each M range:

| M range | Kernel              | Rationale                                    |
|---------|---------------------|----------------------------------------------|
| 1       | Scalar GEMV v8      | 3-5x faster than cuBLAS, MMA wastes 94%      |
| 2-4     | MMA dequant v2      | Tensor cores amortize dequant, data savings  |
| 5-16    | MMA dequant v2      | Increasing MMA utilization, still data-bound |
| 17+     | MMA prod (existing) | Full MMA utilization, existing kernel works  |
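
The policy table translates to a trivial dispatcher. The kernel names here are placeholders, not the actual entry points in the codebase:

```python
# Illustrative dispatch sketch for the policy table above.
def pick_kernel(m: int) -> str:
    if m == 1:
        return "scalar_gemv_v8"   # MMA would waste 94% of its outputs
    if m <= 16:
        return "mma_dequant_v2"   # tensor cores amortize dequant cost
    return "kbit_gemm_prod"       # full MMA utilization at M >= 17
```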

### Architecture: MMA Dequant v2

Key changes from `kbit_gemm_prod`:

#### A. Reduce TILE_N from 128 to 64

This is the single most impactful change for SM utilization:

| TILE_N | n_tiles (N=5120) | shmem/stage | Max blocks/SM | Notes                |
|--------|------------------|-------------|---------------|----------------------|
| 128    | 40               | 6400 B      | 8             | Current, 31% SM      |
| 64     | 80               | 4224 B      | 12            | 62% SM at k_splits=1 |
| 32     | 160              | 3136 B      | 16            | 100%+ SM             |

TILE_N=64 with k_splits=2 gives 160 work items = 100% SM utilization.
TILE_N=32 gives 160 tiles without needing k_splits, avoiding atomicAdd overhead.

Recommendation: **TILE_N=64 with k_splits=2** for the best balance of SM
utilization vs. per-block work granularity. Consider TILE_N=32 as a fallback
for shapes where N is small.

Block structure at TILE_N=64:
- 128 threads (4 warps), each warp handles 16 columns (2 MMA N-blocks of 8)
- Or 256 threads (8 warps), each warp handles 8 columns (1 MMA N-block)
- Prefer 128 threads: fewer warps = more blocks/SM, better for small M
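
The two block structures above can be checked with the column arithmetic:

```python
# Warp-to-column mapping at TILE_N=64; MMA n8 blocks are 8 columns wide.
TILE_N, MMA_N = 64, 8

def cols_per_warp(threads_per_block):
    warps = threads_per_block // 32
    cols = TILE_N // warps             # columns each warp owns
    return warps, cols, cols // MMA_N  # warps, cols/warp, N-blocks/warp

# 128 threads: 4 warps * 16 cols (2 N-blocks each)
# 256 threads: 8 warps *  8 cols (1 N-block each)
```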

#### B. Decouple Dequant from MMA via Shared Memory

Current flow (per warp, per k-step):
```
load planes from shmem → bit extract → shuffle → scale → pack frag_b → MMA
```
This serializes dequant and MMA. The tensor cores idle during dequant.

Proposed flow, **dequant-to-shmem**:
```
Phase 1: All threads cooperatively dequant B tile → fp16 values in shmem
Phase 2: ldmatrix loads dequanted B from shmem → MMA
```

Benefits:
- `ldmatrix` is a single instruction to load a full MMA fragment from shmem
- The MMA pipeline stays full: no scalar dequant in the critical path
- All threads participate in dequant (better parallelism)
- Clean double-buffering: dequant tile K+1 while MMA processes tile K

Shmem cost at TILE_N=64:
- B dequanted: 64 * 64 * 2 = 8192 bytes per stage
- A: 16 * 64 * 2 = 2048 bytes per stage
- Total: 10240 bytes/stage, 20480 bytes double-buffered
- Max 5 blocks/SM (100 KB limit). With 128-thread blocks that is
  4 warps/block * 5 blocks = 20 resident warps, ~42% occupancy
  (48-warp max per SM).

At TILE_N=32:
- B dequanted: 32 * 64 * 2 = 4096 bytes
- Total: 6144 bytes/stage, 12288 bytes double-buffered
- Max 8 blocks/SM → 32 warps = 67% occupancy. Better.

Trade-off: TILE_N=32 has better occupancy but 2x more tiles to process
and less N-parallelism per block.
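
The shmem and occupancy budgets above, assuming the 100 KB smem limit and 48-warp cap stated in this section, double-buffered stages, and 4 warps per block:

```python
# Shmem/occupancy arithmetic for the dequant-to-shmem design.
SMEM_PER_SM, MAX_WARPS, TILE_M, TILE_K = 100 * 1024, 48, 16, 64

def occupancy(tile_n, warps_per_block=4):
    b_dequant = tile_n * TILE_K * 2        # fp16 B tile per stage
    a_tile = TILE_M * TILE_K * 2           # fp16 A tile per stage
    per_block = 2 * (b_dequant + a_tile)   # double-buffered
    blocks = SMEM_PER_SM // per_block      # smem-limited blocks/SM
    warps = blocks * warps_per_block
    return per_block, blocks, min(warps, MAX_WARPS) / MAX_WARPS

# TILE_N=64 → 20480 B/block, 5 blocks, ~42% occupancy
# TILE_N=32 → 12288 B/block, 8 blocks, ~67% occupancy
```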

#### C. Cooperative Dequant

In the dequant-to-shmem approach, all threads participate in dequanting:

```
Elements per tile: TILE_N * TILE_K = 64 * 64 = 4096 (at TILE_N=64)
Threads per block: 128
Elements per thread: 32
```

Each thread:
1. Loads K_BITS packed uint32 planes from B shmem (already fetched via cp.async)
2. Extracts bit indices for its assigned elements
3. Does `__shfl_sync` for codebook lookup
4. Multiplies by scale (absmax)
5. Writes fp16 result to B_dequant shmem

This is essentially the scalar GEMV's inner loop, but writing to shmem
instead of accumulating. The `__shfl_sync` requires all lanes to participate,
so threads within a warp must process elements from the same quantization
block (same codebook lookup pattern).

Thread mapping for dequant:
- 128 threads process 4096 elements = 128 quant blocks of 32 elements each
- Thread t handles quant block t (for TILE_K=64, KB_PER_TILE=2: 64 cols * 2 blocks)
- Each thread dequants 32 elements, writes 32 fp16 values to shmem

After `__syncthreads()`, all threads switch to MMA consumer role.
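
The mapping above as checkable arithmetic. The column-major ordering of quant blocks in `quant_block_for_thread` is one plausible layout, not a statement about the actual kernel:

```python
# Cooperative-dequant thread mapping at TILE_N=64:
# 4096 tile elements split into 32-element quantization blocks.
TILE_N, TILE_K, QBLOCK, THREADS = 64, 64, 32, 128

elements = TILE_N * TILE_K          # 4096 fp16 values per tile
quant_blocks = elements // QBLOCK   # 128 blocks of 32 elements
per_thread = elements // THREADS    # 32 elements per thread

def quant_block_for_thread(t):
    # thread t owns quant block t; KB_PER_TILE=2 blocks per column
    # (hypothetical ordering: consecutive threads walk one column)
    return t // 2, t % 2            # (column, k-block within column)
```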

#### D. Smarter k_splits Heuristic

The current heuristic is too conservative. Replace with:

```
mn_tiles = m_tiles * n_tiles
target_blocks = num_sms          // fill all SMs

if mn_tiles >= target_blocks:
    k_splits = 1                 // enough parallelism from M*N tiles
else:
    k_splits = min(k_tiles, ceil(target_blocks / mn_tiles))
    k_splits = min(k_splits, 4)  // cap to limit atomicAdd overhead
```

For M=2, N=5120, TILE_N=64: mn_tiles=80, target=128, k_splits=2,
total_work=160. All SMs active.
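
An executable version of the heuristic above; the parameter names mirror the pseudocode and would come from the launch configuration in practice:

```python
import math

def choose_k_splits(m_tiles, n_tiles, k_tiles, num_sms=128):
    mn_tiles = m_tiles * n_tiles
    if mn_tiles >= num_sms:
        return 1                     # enough parallelism from M*N tiles
    k_splits = min(k_tiles, math.ceil(num_sms / mn_tiles))
    return min(k_splits, 4)          # cap to limit atomicAdd overhead

# M=2, N=5120, TILE_N=64: m_tiles=1, n_tiles=80
# → k_splits=2, total_work = 80 * 2 = 160, all SMs active
```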

#### E. Avoid A Waste at Small M

At M < TILE_M (=16), most of the A tile is zero-padded. Two approaches:

**Option 1: Guard the cp.async** (current approach, already implemented).
Only fetch rows 0..M-1. Remaining shmem rows are zeroed cheaply.
This already works but wastes shmem space.

**Option 2: Dynamic TILE_M.** Use M_BLOCKS=1 (TILE_M=16) always for M<=16,
and accept the A waste. The A tile is small (2 KB) relative to B (4-8 KB),
so the waste is tolerable. Not worth the complexity of variable TILE_M.

Recommendation: Keep the current approach. The A waste is minor.

---

## Implementation Plan

### Phase 1: TILE_N=64 + Aggressive k_splits

Minimal changes to `kbit_gemm_prod`:
1. Add a TILE_N=64 variant (template parameter or separate kernel)
2. Reduce the block to 128 threads (4 warps)
3. Update the k_splits heuristic to always fill SMs
4. Update the dispatcher to use TILE_N=64 for M <= 16

Expected impact: SM utilization 31% → 100%. Estimated 2-3x speedup for
small M, bringing the MMA kernel into the ~15-20 us range.

### Phase 2: Dequant-to-Shmem

Major restructure of the compute loop:
1. Add a B_dequant shmem buffer (TILE_N * TILE_K * 2 bytes per stage)
2. Split compute_tile into dequant_phase + mma_phase with `__syncthreads` between
3. Dequant phase: all threads extract bits, shuffle codebook, write fp16 to shmem
4. MMA phase: ldmatrix loads B fragments from shmem, runs MMA
5. Double-buffer: overlap dequant of tile K+1 with MMA of tile K

Expected impact: removes dequant from the MMA critical path. Combined with
Phase 1, estimated 10-14 us for M=2-4 (competitive with cuBLAS at 12.3 us).

### Phase 3: Tuning

1. Profile with ncu, identify remaining bottlenecks
2. Tune TILE_N (32 vs 64) per shape
3. Tune the k_splits cap (2 vs 4)
4. Consider warp specialization (dedicated dequant vs MMA warps)
5. Consider a persistent kernel for Phase 2 (reuse shmem across tiles)

---

## Expected Results

| M  | Current MMA (est) | Phase 1 (est) | Phase 2 (est) | cuBLAS fp16 | Scalar GEMV v8 |
|----|-------------------|---------------|---------------|-------------|----------------|
| 1  | ~40 us            | ~20 us        | ~15 us        | 51.1 us     | **13.1 us**    |
| 2  | ~42 us            | ~18 us        | ~12 us        | 12.3 us     | 14.8 us        |
| 4  | ~44 us            | ~16 us        | ~10 us        | 12.5 us     | 19.8 us        |
| 8  | ~46 us            | ~14 us        | ~9 us         | ~12.5 us    | N/A            |
| 16 | ~20 us            | ~12 us        | ~8 us         | ~12.5 us    | N/A            |
At M=1, scalar GEMV v8 remains the best choice. At M>=2, the optimized MMA
kernel should match or beat cuBLAS while reading 3.2x less data. The crossover
at M~2 stays in the same place, but it is now between our own kernels (scalar
GEMV vs MMA) rather than against cuBLAS, giving the best of both worlds.

## Theoretical Limits

DRAM payload for k=4, K=2048, N=5120 (independent of M for M<=16):
- B_packed: 5.24 MB, B_absmax: 1.31 MB, A: negligible
- Total: ~6.6 MB
- At 100% DRAM peak (1008 GB/s): 6.5 us
- At 75%: 8.7 us
- At 50%: 13.0 us

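The payload and time figures above can be reproduced as follows; the fp32 absmax per 32-element quant block is an assumption inferred from the 1.31 MB figure:

```python
# DRAM-payload arithmetic for k=4, K=2048, N=5120.
K, N, K_BITS, QBLOCK = 2048, 5120, 4, 32

b_packed = K * N * K_BITS / 8     # 5.24 MB of packed weights
b_absmax = K * N // QBLOCK * 4    # 1.31 MB of fp32 scales (assumed fp32)
total = b_packed + b_absmax       # ~6.55 MB; A is negligible for M<=16

def t_us(nbytes, frac, peak=1008e9):
    # transfer time in us at a fraction of DRAM peak bandwidth
    return nbytes / (peak * frac) * 1e6

# t_us(total, 1.00) ≈ 6.5, t_us(total, 0.75) ≈ 8.7, t_us(total, 0.50) ≈ 13.0
```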
cuBLAS fp16 reads 21.0 MB (3.2x more). Even at 100% DRAM utilization,
cuBLAS cannot go below 20.8 us for a pure memory-bound GEMV. The reason
cuBLAS achieves 12.3 us at M=2 is that it switches to a compute-bound
tensor core GEMM that reuses data in registers/shmem.

Our MMA kernel's advantage: read 6.6 MB instead of 21.0 MB. If we can
keep the tensor core pipeline fed, the 3.2x data reduction translates
directly to a 3.2x speed advantage at the DRAM-bound limit.