Commit fe58198

TimDettmers and claude committed

Add MMA kernel optimization spec for M=2-16 range

Analysis of why kbit_gemm_prod is slow at small M (31% SM util, 94% MMA
waste, dequant bottleneck) and a two-phase plan to fix it. Phase 1:
TILE_N=64 + aggressive k_splits for SM utilization. Phase 2:
dequant-to-shmem to decouple dequant from MMA pipeline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent aba14e8 commit fe58198

1 file changed: mma_optimizations.md (+294 lines)
# MMA Kernel Optimization Spec

## Current State

The scalar GEMV kernel (v8) handles M=1-4 efficiently, achieving 3-5x speedup
over cuBLAS fp16 at M=1. However, at M>=2, cuBLAS switches to tensor core GEMM
and is 1.2-1.6x faster than our scalar kernel. The existing MMA kernel
(`kbit_gemm_prod`) is too slow at small M to fill this gap.

### Scalar GEMV v8 (k=4, shape 0: K=2048 N=5120)

| M | us | GB/s | vs cuBLAS fp16 |
|---|------|------|----------------|
| 1 | 13.1 | 512 | 3.9x faster |
| 2 | 14.8 | 450 | 1.2x slower |
| 3 | 16.6 | 401 | 1.3x slower |
| 4 | 19.8 | 337 | 1.6x slower |

cuBLAS fp16: ~12.3 us for M=2-4 (tensor cores, flat scaling).

### Target

An MMA-based dequant kernel that beats cuBLAS fp16 for M=2-16 by leveraging
the 3.2x data compression from k-bit quantization while using tensor cores for
the multiply-accumulate. Target: **8-10 us for M=2-4** (matching the theoretical
DRAM minimum of 8.7 us at 75% bandwidth).
---

## Why the Current MMA Kernel is Slow

Three compounding problems at small M, analyzed for k=4, K=2048, N=5120:

### 1. SM Utilization: 31%

With TILE_N=128, there are only `N/128 = 40` n-tiles. At M<=16, `m_tiles=1`,
so `total_work = 40`. On 128 SMs (RTX 4090), 88 SMs sit completely idle.

The k_splits heuristic doesn't trigger because B data (5.6 MB) is under the
24 MB DRAM threshold. Even with aggressive k_splits:

| k_splits | total_work | grid | SM util |
|----------|-----------|-------|---------|
| 1 | 40 | 40 | 31% |
| 2 | 80 | 80 | 62% |
| 4 | 160 | 128 | 100% |

But k_splits > 1 adds atomicAdd overhead and a __threadfence + tile_counter
synchronization per work item.
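As a sanity check on the table above, the utilization numbers follow from the grid math alone. A minimal sketch (not kernel code), assuming `m_tiles = 1` and a grid capped at the SM count:

```python
# Sketch: SM utilization of the current kernel for N=5120 on a
# 128-SM RTX 4090, assuming m_tiles = 1 (true for M <= 16).
from math import ceil

NUM_SMS = 128

def sm_utilization(n, tile_n, k_splits, num_sms=NUM_SMS):
    """Fraction of SMs that receive at least one block."""
    total_work = ceil(n / tile_n) * k_splits  # n_tiles * k_splits
    return min(total_work, num_sms) / num_sms

print(sm_utilization(5120, 128, 1))  # 0.3125 -> the 31% row
print(sm_utilization(5120, 128, 2))  # 0.625  -> 62%
print(sm_utilization(5120, 128, 4))  # 1.0    -> 100% (160 blocks, capped)
```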
### 2. MMA Compute Waste: 75-94%

`mma.sync.aligned.m16n8k16` is the smallest MMA tile on sm_89. It computes
16 M-rows regardless of actual M. At M=1, 15/16 rows are zero-padded:

| M | Useful outputs | Total MMA outputs | Utilization |
|----|---------------|-------------------|-------------|
| 1 | 128 | 2048 | 6.2% |
| 2 | 256 | 2048 | 12.5% |
| 4 | 512 | 2048 | 25.0% |
| 8 | 1024 | 2048 | 50.0% |
| 16 | 2048 | 2048 | 100.0% |

This is an inherent hardware limitation — there is no m4n8k16 or m8n8k16 on
Ada Lovelace. M < 16 always wastes MMA compute.
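The table reduces to a single ratio: useful rows over the 16 rows each MMA computes. A sketch reproducing the utilization column (assumes the 128-column tile from the table):

```python
# Sketch: M-row utilization of m16n8k16 — every MMA computes 16 rows,
# so the useful fraction is M/16 for M <= 16 (TILE_N cancels out).
MMA_M = 16
TILE_N = 128

def mma_utilization(m):
    useful = m * TILE_N      # outputs actually needed
    total = MMA_M * TILE_N   # outputs the tensor core computes anyway
    return useful / total

for m in (1, 2, 4, 8, 16):
    print(m, mma_utilization(m))  # 0.0625, 0.125, 0.25, 0.5, 1.0
```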
### 3. A Tile DRAM Waste

Each A stage loads TILE_M * TILE_K * 2 = 2048 bytes, but at M=1 only
128 bytes are useful (6%). At M=4: 512 bytes useful (25%). This wastes
DRAM bandwidth and cp.async slots.

### 4. Dequant is the Bottleneck, Not MMA

Per B element, dequant requires:
- k bit extractions (shift + AND + shift + OR each): ~3k instructions
- 1 `__shfl_sync` (codebook lookup): 1 instruction
- 1 scale multiply: 1 instruction
- Total: ~3k + 2 instructions per element (14 for k=4)

Per TILE_N x TILE_K tile: 128 * 64 = 8192 elements to dequant.
Each thread dequants 4 elements per iteration (idx0-idx3), so
8192 / 4 / 32 lanes = 64 iterations per warp.

The MMA instruction (m16n8k16) takes ~8 cycles on tensor cores.
The dequant to prepare one B fragment takes ~64 scalar instructions.
**MMA is not the bottleneck — dequant is.**
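The instruction-count arithmetic above, as a back-of-envelope sketch (the ~3 instructions per extracted bit is the estimate from the list, not a measured count):

```python
# Sketch: per-element dequant cost and per-warp iteration count for the
# current TILE_N=128, TILE_K=64 tile.
def instrs_per_element(k):
    # ~3 instructions per bit plane (shift/AND/shift/OR) + shuffle + scale
    return 3 * k + 2

TILE_N, TILE_K = 128, 64
elements = TILE_N * TILE_K            # 8192 B elements per tile
iters_per_warp = elements // 4 // 32  # 4 elements/thread/iter, 32 lanes

print(instrs_per_element(4))          # 14
print(elements, iters_per_warp)       # 8192 64
```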
---

## Optimization Strategy

### Dispatch Policy

Use the right kernel for each M range:

| M range | Kernel | Rationale |
|---------|-----------------|----------------------------------------------|
| 1 | Scalar GEMV v8 | 3-5x faster than cuBLAS, MMA wastes 94% |
| 2-4 | MMA dequant v2 | Tensor cores amortize dequant, data savings |
| 5-16 | MMA dequant v2 | Increasing MMA utilization, still data-bound |
| 17+ | MMA prod (existing) | Full MMA utilization, existing kernel works |
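A host-side dispatcher implementing the policy table could look like the sketch below. The function name and the string labels are hypothetical; the real dispatcher would select actual kernel entry points:

```python
# Hypothetical dispatch sketch for the policy table above.
def pick_kernel(m: int) -> str:
    if m == 1:
        return "scalar_gemv_v8"   # 3-5x faster than cuBLAS at M=1
    if m <= 16:
        return "mma_dequant_v2"   # tensor cores amortize dequant
    return "kbit_gemm_prod"       # existing kernel, full MMA utilization

print(pick_kernel(1), pick_kernel(4), pick_kernel(32))
```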
### Architecture: MMA Dequant v2

Key changes from `kbit_gemm_prod`:

#### A. Reduce TILE_N from 128 to 64

This is the single most impactful change for SM utilization:

| TILE_N | n_tiles (N=5120) | shmem/stage | Max blocks/SM | Notes |
|--------|-----------------|-------------|---------------|-----------------|
| 128 | 40 | 6400 B | 8 | Current, 31% SM |
| 64 | 80 | 4224 B | 12 | 62% SM at k_splits=1 |
| 32 | 160 | 3136 B | 16 | 100%+ SM |

TILE_N=64 with k_splits=2 gives 160 work items = 100% SM utilization.
TILE_N=32 gives 160 tiles without needing k_splits, avoiding atomicAdd overhead.

Recommendation: **TILE_N=64 with k_splits=2** for the best balance of SM
utilization vs. per-block work granularity. Consider TILE_N=32 as a fallback
for shapes where N is small.

Block structure at TILE_N=64:
- 128 threads (4 warps), each warp handles 16 columns (2 MMA N-blocks of 8)
- Or 256 threads (8 warps), each warp handles 8 columns (1 MMA N-block)
- Prefer 128 threads: fewer warps = more blocks/SM, better for small M
#### B. Decouple Dequant from MMA via Shared Memory

Current flow (per warp, per k-step):
```
load planes from shmem → bit extract → shuffle → scale → pack frag_b → MMA
```
This serializes dequant and MMA. The tensor cores idle during dequant.

Proposed flow — **dequant-to-shmem**:
```
Phase 1: All threads cooperatively dequant B tile → fp16 values in shmem
Phase 2: ldmatrix loads dequanted B from shmem → MMA
```

Benefits:
- `ldmatrix` is a single instruction to load a full MMA fragment from shmem
- MMA pipeline stays full — no scalar dequant in the critical path
- All threads participate in dequant (better parallelism)
- Clean double-buffering: dequant tile K+1 while MMA processes tile K

Shmem cost at TILE_N=64:
- B dequanted: 64 * 64 * 2 = 8192 bytes per stage
- A: 16 * 64 * 2 = 2048 bytes per stage
- Total: 10240 bytes/stage, 20480 bytes double-buffered
- Max 5 blocks/SM (100 KB limit) → 20 warps resident with 128-thread
  (4-warp) blocks. Occupancy: ~42%.

At TILE_N=32:
- B dequanted: 32 * 64 * 2 = 4096 bytes
- Total: 6144 bytes/stage, 12288 bytes double-buffered
- Max 8 blocks/SM → 32 warps = 67% occupancy. Better.

Trade-off: TILE_N=32 has better occupancy but 2x more tiles to process
and less N-parallelism per block.
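The shmem and occupancy arithmetic above can be sketched directly, assuming fp16 staging, double-buffering, a 100 KB dynamic-shmem budget, 4-warp blocks, and a 48-warp/SM residency limit:

```python
# Sketch: shared-memory footprint and resulting occupancy per TILE_N.
TILE_M, TILE_K = 16, 64
SHMEM_LIMIT = 100 * 1024     # 100 KB usable per SM (assumption for sm_89)
WARPS_PER_BLOCK, MAX_WARPS = 4, 48

def occupancy(tile_n):
    b_dequant = tile_n * TILE_K * 2        # fp16 dequanted B per stage
    a_tile = TILE_M * TILE_K * 2           # fp16 A per stage
    per_block = 2 * (b_dequant + a_tile)   # double-buffered
    blocks = SHMEM_LIMIT // per_block      # shmem-limited blocks/SM
    warps = blocks * WARPS_PER_BLOCK
    return per_block, blocks, warps / MAX_WARPS

print(occupancy(64))  # (20480, 5, 0.416...) -> ~42%
print(occupancy(32))  # (12288, 8, 0.666...) -> 67%
```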
#### C. Cooperative Dequant

In the dequant-to-shmem approach, all threads participate in dequanting:

```
Elements per tile: TILE_N * TILE_K = 64 * 64 = 4096 (at TILE_N=64)
Threads per block: 128
Elements per thread: 32
```

Each thread:
1. Loads K_BITS packed uint32 planes from B shmem (already fetched via cp.async)
2. Extracts bit indices for its assigned elements
3. Does `__shfl_sync` for codebook lookup
4. Multiplies by scale (absmax)
5. Writes fp16 result to B_dequant shmem

This is essentially the scalar GEMV's inner loop, but writing to shmem
instead of accumulating. The `__shfl_sync` requires all lanes to participate,
so threads within a warp must process elements from the same quantization
block (same codebook lookup pattern).

Thread mapping for dequant:
- 128 threads process 4096 elements = 128 quant blocks of 32 elements each
- Thread t handles quant block t (for TILE_K=64, KB_PER_TILE=2: 64 cols * 2 blocks)
- Each thread dequants 32 elements, writes 32 fp16 values to shmem

After `__syncthreads()`, all threads switch to the MMA consumer role.
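The thread-to-quant-block mapping above can be sketched as follows. The column-major ordering (two K-blocks per column) is one plausible layout, an assumption rather than the spec's committed choice:

```python
# Sketch: cooperative-dequant work split at TILE_N=64 — one 32-element
# quant block per thread, two K-blocks per column.
TILE_N, TILE_K, QBLOCK = 64, 64, 32
THREADS = 128

elements = TILE_N * TILE_K          # 4096
quant_blocks = elements // QBLOCK   # 128 -> exactly one per thread

def thread_to_block(t):
    # Assumed layout: thread t -> (column, k-block within that column).
    kb_per_col = TILE_K // QBLOCK   # 2
    return t // kb_per_col, t % kb_per_col

print(quant_blocks, elements // THREADS)         # 128 32
print(thread_to_block(0), thread_to_block(127))  # (0, 0) (63, 1)
```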
#### D. Smarter k_splits Heuristic

The current heuristic is too conservative. Replace with:

```
mn_tiles = m_tiles * n_tiles
target_blocks = num_sms  // fill all SMs

if mn_tiles >= target_blocks:
    k_splits = 1  // enough parallelism from M*N tiles
else:
    k_splits = min(k_tiles, ceil(target_blocks / mn_tiles))
    k_splits = min(k_splits, 4)  // cap to limit atomicAdd overhead
```

For M=2, N=5120, TILE_N=64: mn_tiles=80, target=128, k_splits=2,
total_work=160. All SMs active.
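A runnable version of the pseudocode, reproducing the worked example (hedged: the real dispatcher may clamp or name things differently):

```python
# Sketch of the proposed k_splits heuristic.
from math import ceil

def choose_k_splits(m_tiles, n_tiles, k_tiles, num_sms=128, cap=4):
    mn_tiles = m_tiles * n_tiles
    if mn_tiles >= num_sms:
        return 1  # enough parallelism from M*N tiles alone
    k_splits = min(k_tiles, ceil(num_sms / mn_tiles))
    return min(k_splits, cap)  # cap to limit atomicAdd overhead

# M=2, N=5120, TILE_N=64 -> m_tiles=1, n_tiles=80; K=2048/TILE_K=64 -> 32
ks = choose_k_splits(m_tiles=1, n_tiles=80, k_tiles=32)
print(ks, 80 * ks)  # 2 160 -> all 128 SMs active
```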
#### E. Avoid A Waste at Small M

At M < TILE_M (=16), most of the A tile is zero-padded. Two approaches:

**Option 1: Guard the cp.async** (current approach, already implemented).
Only fetch rows 0..M-1. Remaining shmem rows are zeroed cheaply.
This already works but wastes shmem space.

**Option 2: Dynamic TILE_M.** Use M_BLOCKS=1 (TILE_M=16) always for M<=16,
and accept the A waste. The A tile is small (2 KB) relative to B (4-8 KB),
so the waste is tolerable. Not worth the complexity of variable TILE_M.

Recommendation: Keep current approach. A waste is minor.

---
## Implementation Plan

### Phase 1: TILE_N=64 + Aggressive k_splits

Minimal changes to `kbit_gemm_prod`:
1. Add a TILE_N=64 variant (template parameter or separate kernel)
2. Reduce block to 128 threads (4 warps)
3. Update k_splits heuristic to always fill SMs
4. Update dispatcher to use TILE_N=64 for M <= 16

Expected impact: SM utilization 31% → 100%. Estimated 2-3x speedup for
small M, bringing the MMA kernel into the ~15-20 us range.

### Phase 2: Dequant-to-Shmem

Major restructure of the compute loop:
1. Add B_dequant shmem buffer (TILE_N * TILE_K * 2 bytes per stage)
2. Split compute_tile into dequant_phase + mma_phase with __syncthreads between
3. Dequant phase: all threads extract bits, shuffle codebook, write fp16 to shmem
4. MMA phase: ldmatrix loads B fragments from shmem, runs MMA
5. Double-buffer: overlap dequant of tile K+1 with MMA of tile K

Expected impact: removes dequant from the MMA critical path. Combined with
Phase 1, estimated 10-14 us for M=2-4 (competitive with cuBLAS 12.3 us).

### Phase 3: Tuning

1. Profile with ncu, identify remaining bottlenecks
2. Tune TILE_N (32 vs 64) per shape
3. Tune k_splits cap (2 vs 4)
4. Consider warp specialization (dedicated dequant vs MMA warps)
5. Consider persistent kernel for Phase 2 (reuse shmem across tiles)

---
## Expected Results

| M | Current MMA (est) | Phase 1 (est) | Phase 2 (est) | cuBLAS fp16 | Scalar GEMV v8 |
|---|-------------------|---------------|---------------|-------------|---------------|
| 1 | ~40 us | ~20 us | ~15 us | 51.1 us | **13.1 us** |
| 2 | ~42 us | ~18 us | ~12 us | 12.3 us | 14.8 us |
| 4 | ~44 us | ~16 us | ~10 us | 12.5 us | 19.8 us |
| 8 | ~46 us | ~14 us | ~9 us | ~12.5 us | N/A |
| 16| ~20 us | ~12 us | ~8 us | ~12.5 us | N/A |

At M=1, scalar GEMV v8 remains the best choice. At M>=2, the optimized MMA
kernel should match or beat cuBLAS while reading 3.2x less data. The crossover
between scalar GEMV and MMA stays at M~2, but both sides of it are now our own
kernels rather than cuBLAS, giving the best of both worlds.
## Theoretical Limits

DRAM payload for k=4, K=2048, N=5120 (independent of M for M<=16):
- B_packed: 5.24 MB, B_absmax: 1.31 MB, A: negligible
- Total: ~6.6 MB
- At 100% DRAM peak (1008 GB/s): 6.5 us
- At 75%: 8.7 us
- At 50%: 13.0 us
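The payload figures above are reproduced by the sketch below, assuming 4-bit packed B and one fp32 absmax per 32-element quantization block (the layout consistent with the 1.31 MB figure):

```python
# Sketch: DRAM lower bound for k=4, K=2048, N=5120 at 1008 GB/s peak.
K, N, k_bits = 2048, 5120, 4
b_packed = K * N * k_bits / 8  # 5.24 MB of packed B
b_absmax = K * N / 32 * 4      # 1.31 MB (fp32 absmax per 32-elem block)
total = b_packed + b_absmax    # ~6.6 MB; A is negligible at small M

PEAK = 1008e9                  # bytes/s, RTX 4090 DRAM peak
for eff in (1.0, 0.75, 0.5):
    print(f"{eff:.0%}: {total / (PEAK * eff) * 1e6:.1f} us")
# 100%: 6.5 us, 75%: 8.7 us, 50%: 13.0 us
```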
cuBLAS fp16 reads 21.0 MB (3.2x more). Even at 100% DRAM utilization,
cuBLAS cannot go below 20.8 us for a purely memory-bound GEMV. The reason
cuBLAS achieves 12.3 us at M=2 is that it switches to a compute-bound
tensor core GEMM that reuses data in registers/shmem.

Our MMA kernel's advantage: read 6.6 MB instead of 21.0 MB. If we can
keep the tensor core pipeline fed, the 3.2x data reduction translates
directly into a 3.2x speed advantage at the DRAM-bound limit.
