Commit 4d51152

TimDettmers and claude committed

docs: Add optimization guide and update progress report

New optimization.md catalogs 5 performance optimizations with expected impact, implementation details, and recommended order:

1. Multi-M-block tiling (highest priority, 2-3x expected)
2. Larger N_BLOCKS per warp (2x, compounds with #1)
3. C output staging through shared memory (5-15%)
4. Persistent kernel (helps low-tile-count shapes)
5. cp.async for A tile (2-5%)

Updated progress.md with current status section, table of contents for Stages 2-6, and pointer to optimization.md.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

1 parent a91c313 commit 4d51152

File tree

2 files changed: +372 −11 lines

optimization.md

Lines changed: 324 additions & 0 deletions
# kbit GEMM Kernel: Optimization Guide

This document catalogs the remaining performance optimizations for the production kbit GEMM kernel (`kbit_gemm_prod`). Each optimization is described with its expected impact, implementation approach, and testing strategy.

The kernel is functionally complete (fp16 + bf16, split-K, ldmatrix with swizzle, cp.async double-buffered pipeline, 139 tests passing). The remaining work is purely about throughput.

---

## Current State (Baseline)

**Kernel configuration:**

- TILE_M = 16 (one MMA M-block per warp)
- TILE_N = 128 (N_BLOCKS = 2, each warp covers 16 columns)
- TILE_K = 64 (4 MMA k-sub-tiles of 16)
- 256 threads = 8 warps; each warp handles the same M rows and a slice of N
- Double-buffered cp.async pipeline
- ldmatrix.x4 with XOR bank-conflict swizzle for the A tile

**RTX 4090 benchmark (K=4, fp16, k_chunks=1):**

| M | K_dim | N | kbit (us) | cuBLAS (us) | Speedup |
|---:|------:|------:|----------:|------------:|--------:|
| 1 | 4096 | 4096 | 109 | 43 | 0.39x |
| 1 | 4096 | 11008 | 82 | 128 | **1.56x** |
| 4 | 4096 | 4096 | 92 | 22 | 0.24x |
| 4 | 4096 | 11008 | 100 | 121 | **1.21x** |
| 16 | 4096 | 4096 | 149 | 28 | 0.19x |

**Why it's slow for square matrices:** Each thread block computes a 16x128 output tile. With M=16, only one M-tile exists, so only (N/128) blocks launch. For N=4096, that's 32 blocks on a 128-SM GPU -- 25% utilization. And each block does very little compute per shared-memory load, because TILE_M=16 means only one MMA row-block per warp.

**Why it wins for M=1, large N:** The GEMM is memory-bandwidth-bound. The kernel reads 4-bit compressed weights (4x less data than fp16 cuBLAS), which translates directly into speedup.

---

## Optimization 1: Multi-M-Block Tiling

**Priority: HIGHEST. This is the single biggest performance lever.**

### The Problem

Currently TILE_M=16. Each warp executes 2 MMA operations per k-sub-tile (N_BLOCKS=2). The A fragment is loaded once and used for only 2 MMAs, so the compute-to-load ratio is low.

### The Fix

Template the kernel on `M_BLOCKS` (1, 2, 3, 4). TILE_M becomes `M_BLOCKS * 16`. Each warp handles multiple M-blocks, reusing the same B fragment across all of them:

```
Current (M_BLOCKS=1):
  Each warp: 1 M-block x 2 N-blocks = 2 MMAs per k-sub-tile

Target (M_BLOCKS=4):
  Each warp: 4 M-blocks x 2 N-blocks = 8 MMAs per k-sub-tile
```

The B fragment (dequantized from bit-planes) is the expensive part -- codebook lookup via shuffle, absmax multiply. With M_BLOCKS=4, this cost is amortized over 4x more MMA operations.

### Implementation

1. Add an `M_BLOCKS` template parameter to `kbit_gemm_prod`
2. The FragC accumulator becomes `float frag_c[M_BLOCKS][N_BLOCKS][4]`
3. A fragment loading: load `M_BLOCKS` fragments per k-sub-tile (one ldmatrix for each M-block's 16 rows)
4. Inner loop: for each B fragment, iterate over M_BLOCKS and issue MMAs
5. The A tile in shared memory grows to `M_BLOCKS * 16 * TILE_K * sizeof(scalar_t)`
6. Output write: iterate over M_BLOCKS for the C tile write
7. Host-side dispatch selects M_BLOCKS based on M:
   - M <= 16: M_BLOCKS=1
   - M <= 32: M_BLOCKS=2
   - M <= 48: M_BLOCKS=3
   - M >= 49: M_BLOCKS=4
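
The dispatch table above amounts to a ceiling division capped at 4. A minimal host-side sketch (`select_m_blocks` is an illustrative name, not a function from the actual kernel source):

```cpp
#include <cassert>

// Pick the smallest M_BLOCKS whose TILE_M = M_BLOCKS * 16 covers M,
// capped at 4 (the largest templated variant). Mirrors the table above.
inline int select_m_blocks(int m) {
    const int needed = (m + 15) / 16;  // ceil(M / 16)
    return needed < 4 ? needed : 4;
}
```

A dispatch like this keeps partial M-tiles (e.g. M=17) on the smallest tile that still covers all rows.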

### Shared Memory Impact

| M_BLOCKS | TILE_M | A tile (bytes) | B tile K=4 | Absmax | Per stage | 2 stages |
|---------:|-------:|---------------:|-----------:|-------:|----------:|---------:|
| 1 | 16 | 2,048 | 4,096 | 256 | 6,400 | 12,800 |
| 2 | 32 | 4,096 | 4,096 | 256 | 8,448 | 16,896 |
| 4 | 64 | 8,192 | 4,096 | 256 | 12,544 | 25,088 |

All configurations fit within the RTX 4090's 100 KB shared-memory limit. Even M_BLOCKS=4 with 4 pipeline stages would use only ~50 KB.
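
The per-stage figures in the table follow directly from the tile shapes. A sketch of the arithmetic (helper name is illustrative), assuming TILE_N=128, TILE_K=64, 2-byte activations, and the fixed 256-byte absmax buffer:

```cpp
#include <cassert>

// Shared memory for one pipeline stage, in bytes:
// A tile (TILE_M x TILE_K fp16/bf16) + B tile (TILE_N x TILE_K at
// k bits, bit-plane packed) + 256-byte absmax buffer.
inline int smem_per_stage(int m_blocks, int k_bits) {
    const int tile_m = m_blocks * 16;
    const int a_tile = tile_m * 64 * 2;        // 2 bytes per element
    const int b_tile = 128 * 64 * k_bits / 8;  // packed bit-planes
    const int absmax = 256;
    return a_tile + b_tile + absmax;
}
```

Doubling the result for the two pipeline stages reproduces the last column of the table.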

### Register Impact

FragC grows from 2*4 = 8 floats to M_BLOCKS*2*4 = 32 floats at M_BLOCKS=4. FragA grows from 4 uint32 values to M_BLOCKS*4 = 16. Total register usage lands around 50-60 per thread, well within the 255-register limit.

### Expected Speedup

For M=4, K_dim=4096, N=4096 with M_BLOCKS=4, each block does 4x more compute per B tile load. Since the kernel is currently B-load-limited at these sizes, expect roughly a **2-3x improvement** (not the full 4x, due to diminishing returns from A tile growth).

### Test Strategy

- M_BLOCKS=1 must produce output identical to the current kernel (bit-exact)
- M_BLOCKS=2, 3, 4 must match the Python reference within existing tolerance
- Test partial M-tiles, e.g. M=5 with M_BLOCKS=4 (TILE_M=64, only 5 rows valid)

---

## Optimization 2: Larger N_BLOCKS per Warp

**Priority: HIGH. Complements multi-M-block tiling.**

### The Problem

Currently N_BLOCKS=2, so each warp covers 16 of the 128 tile columns. With 8 warps, that's 8*16 = 128 columns (the full tile), but each warp issues only 2 MMA ops per k-sub-tile per M-block.

### The Fix

Increase N_BLOCKS to 4, so each warp covers 32 columns. Then 4 warps cover the full TILE_N=128, and the remaining 4 warps cover additional M rows (the 2-warps-along-M x 4-warps-along-N layout from the design doc).

### Warp Layout

For TILE_M=64, TILE_N=128, the design doc specifies:

```
  2 warps along M (each handles 32 rows = 2 M-blocks)
x 4 warps along N (each handles 32 cols = 4 N-blocks)
= 8 warps total

Each warp: 2 M-blocks x 4 N-blocks = 8 MMAs per k-sub-tile
With TILE_K=64 (4 k-sub-tiles): 32 MMAs per warp per K-tile
```

This is the target configuration. Combined with multi-M-block tiling, it gives each warp 4x more compute than the current kernel.

### Implementation

1. Change N_BLOCKS to 4
2. Change the warp-to-tile mapping: `warp_m = warp_id / 4`, `warp_n = warp_id % 4`
3. Each warp handles M-blocks `[warp_m * M_BLOCKS_PER_WARP ... (warp_m+1) * M_BLOCKS_PER_WARP - 1]` and N-blocks `[warp_n * 4 ... warp_n * 4 + 3]`
4. Fragment accumulators become `frag_c[M_BLOCKS_PER_WARP][4][4]`
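
The mapping above reduces to plain index math. A sketch (illustrative names; assumes M_BLOCKS_PER_WARP=2, 16-row M-blocks, and 8-column N-blocks, per the warp layout described earlier):

```cpp
#include <cassert>

struct WarpTile { int row0; int col0; };

// First output row/column owned by a warp in the 2-along-M x 4-along-N
// layout: each warp covers 2 M-blocks (32 rows) and 4 N-blocks (32 cols).
inline WarpTile warp_tile_origin(int warp_id) {
    const int warp_m = warp_id / 4;
    const int warp_n = warp_id % 4;
    return { warp_m * 2 * 16, warp_n * 4 * 8 };
}
```

Warps 0-3 tile the first 32 rows across the 128 columns; warps 4-7 tile the next 32 rows, covering the 64x128 output tile with no overlap.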

### Expected Speedup

Combined with multi-M-block tiling, each thread block does **8x** more compute per B tile load than the current kernel (4x from M, 2x from N). For M>=4 square matrices, expect the kernel to **match or beat cuBLAS**.

---

## Optimization 3: C Output Staging Through Shared Memory

**Priority: MEDIUM. Improves memory write efficiency.**

### The Problem

Currently, each thread writes its FragC values directly to global memory. The MMA fragment layout means threads in a warp write to scattered row positions:

- The thread with gid=0 writes rows 0 and 8
- The thread with gid=1 writes rows 1 and 9
- etc.

These writes hit different cache lines (rows are N*2 bytes apart), causing uncoalesced writes.
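
The scatter follows from the standard m16n8 f32 accumulator layout. A sketch of the per-lane mapping (function name is illustrative):

```cpp
#include <cassert>

struct Coord { int row; int col; };

// Standard m16n8 f32 accumulator mapping: lane l holds c[0..3];
// c0,c1 land in row l/4, c2,c3 in row l/4 + 8, at adjacent columns
// (l%4)*2 and (l%4)*2 + 1.
inline Coord frag_c_coord(int lane, int elem) {
    const int gid = lane / 4;  // the "gid" referenced in the bullets above
    return { gid + (elem < 2 ? 0 : 8), (lane % 4) * 2 + (elem & 1) };
}
```

Four consecutive lanes share a row, but each lane's four elements straddle rows 8 apart, which is why a warp's direct stores touch 16 different cache lines.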

### The Fix

After the K-tile loop, stage the output through shared memory:

1. Each warp writes its FragC values to shared memory in the natural fragment order (scattered rows, but shared memory is fast)
2. `__syncthreads()`
3. All threads cooperatively read from shared memory in row-major order and write to global memory with coalesced access (consecutive threads write consecutive addresses within the same row)

### Shared Memory Reuse

The pipeline's shared memory is no longer needed during the output phase (the K-tile loop is done), so the C staging area can reuse the pipeline buffers. For TILE_M=64, TILE_N=128, the C tile is 64*128*2 = 16 KB in fp16, which fits easily in one pipeline stage's allocation.

### Expected Speedup

Moderate. The output write is not on the critical path for large K_dim (the K-tile loop dominates). For small K_dim, or when the kernel is already close to bandwidth-optimal, this can give a **5-15% improvement**.

---

## Optimization 4: Persistent Kernel

**Priority: MEDIUM. Helps SM utilization at small tile counts.**

### The Problem

The current 2D/3D grid launch creates one block per output tile (or per split-K chunk). When the number of tiles is less than the GPU's SM count, SMs sit idle.

### The Fix

Launch exactly `num_SMs` blocks and have each block loop over its assigned work items (linearized (m_tile, n_tile, k_chunk) triples). Benefits:

1. **Better utilization:** All SMs are always active
2. **Accumulator persistence:** When consecutive work items share the same output tile, the accumulators stay in registers (no atomicAdd needed)
3. **First-contributor optimization:** The first block to write a tile does a plain store to the fp32 workspace (no need to zero it first); only subsequent contributors use atomicAdd

### Implementation

See design doc Section 6 for the full design. The key structure:

```cpp
int total_work = m_tiles * n_tiles * k_chunks;
int work_per_block = div_ceil(total_work, gridDim.x);
int my_start = blockIdx.x * work_per_block;
int my_end = min(my_start + work_per_block, total_work);

int prev_mn = -1;
for (int work_id = my_start; work_id < my_end; work_id++) {
    int mn_id = work_id / k_chunks;       // which output tile
    int k_chunk_id = work_id % k_chunks;  // which K range of that tile
    if (mn_id != prev_mn) {               // moving to a new output tile
        if (prev_mn >= 0) write_output(...);
        zero_accumulators();
        prev_mn = mn_id;
    }
    process_k_range(k_chunk_id, ...);
}
if (prev_mn >= 0) write_output(...);      // flush the last tile
```
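
The linearization keeps all k-chunks of one output tile consecutive, which is exactly what lets the accumulators persist in registers between chunks. A host-side sketch of the decode (illustrative helper names):

```cpp
#include <cassert>

struct Work { int mn_id; int k_chunk_id; };

// Decode a linear work id into (output tile, K range). Consecutive work
// ids share mn_id for k_chunks steps in a row, so a block processing
// them back-to-back never has to spill its accumulators.
inline Work decode_work(int work_id, int k_chunks) {
    return { work_id / k_chunks, work_id % k_chunks };
}
```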

### Expected Speedup

Shape-dependent. For shapes where `m_tiles * n_tiles < num_SMs` (e.g., M=16, N=4096 on a 128-SM GPU: 1*32 = 32 tiles), the persistent kernel can improve throughput **2-3x** by enabling split-K without the atomicAdd overhead. For shapes with many tiles, the benefit is marginal.

---

## Optimization 5: cp.async for A Tile

**Priority: LOW. Minor improvement.**

### The Problem

Currently A is loaded synchronously (element by element) while B and absmax use cp.async. A could also use cp.async for better latency hiding.

### The Complication

A needs bounds checking (`gr < M && gc < K_dim`) and an XOR swizzle on the destination address. cp.async copies from a source address to a destination address, so the swizzle can be applied to the destination. Bounds checking is harder -- cp.async doesn't support conditional copies.

### Possible Approach

Use `cp.async.cg.shared.global` for the interior of the A tile (rows that are guaranteed in-bounds), and synchronous loads only for boundary rows. For TILE_M=64 and M=4096, almost all rows are in-bounds; only the last M-tile may have boundary rows.

### Expected Speedup

Small (2-5%). The A tile is only 2-8 KB per stage, much smaller than the B tile, and the synchronous load latency is already partially hidden by the pipeline.

---

## Recommended Implementation Order

1. **Multi-M-block tiling** (Optimization 1) -- biggest impact, enables the target warp layout
2. **Larger N_BLOCKS** (Optimization 2) -- natural companion to multi-M-block; together they achieve the design doc's target of 32 MMAs per warp per K-tile
3. **C output staging** (Optimization 3) -- polish for write efficiency
4. **Persistent kernel** (Optimization 4) -- improves edge cases
5. **cp.async for A** (Optimization 5) -- diminishing returns

After optimizations 1+2, re-benchmark. If the kernel matches cuBLAS for M=1-32 with large N, the remaining optimizations can be deprioritized in favor of integration work (wiring into Linear4bit, auto-tuning k_chunks).

---

## Integration Work (Not Performance, But Required)

These are not performance optimizations, but they are needed to ship:

- **Wire into LinearNbit module:** Replace the dequant+cuBLAS path with a call to `kbit_gemm_prod` when the conditions are met (CUDA, fp16/bf16, N % 128 == 0, K_dim % 64 == 0)
- **Auto-select k_chunks:** Based on M, N, K_dim, and SM count; formula from design doc Section 6.2
- **Remove staging kernels:** Clean up the Stages 3-5 kernels, keeping only the production kernel and the debug MMA test
- **Lint + PR:** Run ruff/clang-format, merge to main
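
The gating conditions in the first bullet can be sketched as a predicate (a sketch only; the real integration code and its signature may differ):

```cpp
#include <cassert>

// True when the production kernel's layout constraints are satisfied:
// CUDA device tensor, fp16/bf16 dtype, N divisible by TILE_N=128,
// K_dim divisible by TILE_K=64. Otherwise fall back to dequant+cuBLAS.
inline bool kbit_gemm_eligible(bool is_cuda, bool is_fp16_or_bf16,
                               long n, long k_dim) {
    return is_cuda && is_fp16_or_bf16 && n % 128 == 0 && k_dim % 64 == 0;
}
```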
