
Commit 90cd7cf (parent f301ba1)
TimDettmers and claude committed

docs: Add SASS analysis and inner loop optimization steps

SASS analysis of the K=4 M_BLOCKS=2 half kernel reveals:
- 39:1 ALU to tensor core instruction ratio
- Bit extraction creates a 12-deep dependency chain (fixable to depth 4)
- decode_e4m4_absmax branches generate 512 BSSY/BSYNC pairs per block
- 12.5% occupancy limits latency hiding to 2 warps per scheduler

Added Step 4 (bit extraction fix), Step 4b (branchless absmax), and Step 4c (B fragment double-buffering) to the optimization plan.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

File tree: 1 file changed

optimization.md (+118 additions, −23 deletions)
@@ -90,17 +90,62 @@ With k_splits=2 and 16 k_tiles per split, the pipeline is even shorter.
A 3-stage pipeline (instead of 2) provides 2x more latency slack at
the cost of 1 more prefill iteration.

### 3.3 Dequant compute cost (SASS analysis, K=4 M_BLOCKS=2 fp16)

The compiled kernel has **1264 SASS instructions**. The instruction mix:

| Category | Count | % | What |
|----------|------:|---:|------|
| Bit manipulation (SHF+LOP3+IMAD) | 628 | 57% | Dequant + address math |
| Tensor core (HMMA) | 16 | 1.5% | The actual matmul |
| Codebook + scale (SHFL+HMUL2) | 64 | 5.8% | Shuffle lookup + absmax multiply |
| Type conversion (F2FP+F2I+I2F) | 40 | 3.6% | Absmax decode, half↔float |
| Control flow (BRA+BSSY+BSYNC+ISETP) | 187 | 17% | Branches, divergence, compares |
| Memory (LDS+LDSM+LDGSTS+PRMT) | 147 | 13% | Shmem, cp.async, permutes |

**The kernel is 39:1 ALU:tensor-core.** The tensor cores are idle 98.5%
of the time. Three specific problems:

**Problem 1: Bit extraction dependency chain.** The inner loop:
```cpp
for (int b = 0; b < K_BITS; b++)
    idx |= ((planes[b] >> bit_pos) & 1) << b;
```
Each `idx |=` depends on the previous value of `idx`, creating a serial
chain of ~12 dependent operations for K=4. With only 2 warps per
scheduler (occupancy = 12.5%), pipeline stalls of 2 cycles per
dependent pair cannot be hidden. For 32 elements per TILE_K × 32
k_tiles: estimated **~10us of dependency stalls**.

Fix: restructure to a tree reduction with independent extractions:
```cpp
int b0 = (planes[0] >> bit_pos) & 1;  // 4 independent extractions
int b1 = (planes[1] >> bit_pos) & 1;
int b2 = (planes[2] >> bit_pos) & 1;
int b3 = (planes[3] >> bit_pos) & 1;
int idx = b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);  // tree combine
```
This reduces the dependency chain from depth 12 to depth 4. With LOP3
(3-input boolean), the combine is 2 instructions.

**Problem 2: Branchy absmax decode.** `decode_e4m4_absmax` has two
conditional branches (`if raw == 0`, `if e == 0`) that generate 16
BSSY/BSYNC divergence-handling pairs per TILE_K iteration. These
execute 512 times per block (16 × 32 k_tiles). Even when never taken,
each pair costs ~4-6 cycles of convergence overhead = **~2-3us total**.

Fix: make it branchless — compute the normal-path result unconditionally,
then use a predicated select for the edge cases (or just accept that
raw=0 and subnormal absmax are negligibly rare and let the normal
formula handle them, producing a harmless wrong value for impossible
inputs).

**Problem 3: Low occupancy.** 72 registers per thread × 256 threads =
18,432 registers per block. The SM has 65,536 registers, so only 3
blocks fit... but shared memory limits it to 1 block (8 warps). With
4 warp schedulers, each has only 2 warps to choose from. Every memory
or ALU latency that both warps hit simultaneously leaves the scheduler
idle. cuBLAS typically runs at 25-50% occupancy for comparable shapes.
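
This arithmetic can be checked with a small host-side helper. A sketch: the 65,536-register file and the 72 × 256 figures come from the text above, while the shared-memory values are hypothetical placeholders chosen only to reproduce the 1-block limit.

```cpp
#include <algorithm>

// Blocks per SM limited by the register file (65,536 registers per SM,
// as stated in the text) for a given per-thread register count.
int blocks_by_registers(int regs_per_thread, int threads_per_block) {
    const int regfile_per_sm = 65536;
    return regfile_per_sm / (regs_per_thread * threads_per_block);
}

// Blocks per SM limited by shared memory. Both arguments here are
// HYPOTHETICAL stand-ins; the kernel's real usage is only known to
// exceed half the SM's shared memory.
int blocks_by_shmem(int shmem_per_block, int shmem_per_sm) {
    return shmem_per_sm / shmem_per_block;
}

// Resident blocks = the tightest of the per-resource limits.
int resident_blocks(int by_regs, int by_shmem) {
    return std::min(by_regs, by_shmem);
}
```

With 72 regs × 256 threads the register file alone admits 3 blocks, but the shared-memory cap leaves 1 resident block = 8 warps, i.e. 2 warps per scheduler on a 4-scheduler SM.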

---

@@ -230,29 +275,79 @@ compute stalls, barrier stalls, or something else entirely.
This informs whether further optimization should focus on memory access
patterns, compute scheduling, or pipeline structure.

### Step 4: Fix bit extraction dependency chain (HIGH)

**What:** Restructure the inner loop bit extraction to eliminate the
serial `idx |=` dependency chain.

**Current code** (12-deep dependency chain for K=4):
```cpp
int idx = 0;
for (int b = 0; b < K_BITS; b++)
    idx |= ((planes[b] >> bit_pos) & 1) << b;
```

**Fixed code** (4-deep, independent extractions + tree combine):
```cpp
int b0 = (planes[0] >> bit_pos) & 1;
int b1 = (planes[1] >> bit_pos) & 1;
int b2 = (planes[2] >> bit_pos) & 1;
int b3 = (planes[3] >> bit_pos) & 1;
int idx = b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);
```

The 4 extractions are independent (no data dependency). The compiler
can schedule them across pipeline stages. The combine uses LOP3 (2
instructions for a 4-input OR with shifts). Dependency depth: 4 vs 12.

For K=2,3,5: same pattern with 2, 3, or 5 independent extractions.

**Also fix: process 4 elements with interleaved extractions.** Currently
the inner loop processes elements r=0..3 sequentially. Interleaving
the bit extraction across elements increases ILP further — while
element 0's extraction stalls on ALU latency, element 1's extraction
can issue.

**Expected impact:** 15-25% improvement on all shapes by reducing
dependency stalls from ~10us to ~3-4us per 32 k_tiles.

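The equivalence is easy to sanity-check on the host. The following plain C++ sketch (helper names are mine, not from the kernel; plane contents are arbitrary test patterns) compares the serial and tree forms across all bit positions:

```cpp
#include <cstdint>

// Host-side check that the tree combine is bit-for-bit identical to the
// serial loop for K_BITS = 4.
constexpr int K_BITS = 4;

int extract_serial(const uint32_t planes[K_BITS], int bit_pos) {
    int idx = 0;
    for (int b = 0; b < K_BITS; b++)
        idx |= (int)((planes[b] >> bit_pos) & 1u) << b;  // serial chain
    return idx;
}

int extract_tree(const uint32_t planes[K_BITS], int bit_pos) {
    int b0 = (int)((planes[0] >> bit_pos) & 1u);  // independent extractions
    int b1 = (int)((planes[1] >> bit_pos) & 1u);
    int b2 = (int)((planes[2] >> bit_pos) & 1u);
    int b3 = (int)((planes[3] >> bit_pos) & 1u);
    return b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);    // tree combine
}

bool equivalent_for(uint32_t seed) {
    uint32_t planes[K_BITS];
    for (int b = 0; b < K_BITS; b++)
        planes[b] = seed * 2654435761u + (uint32_t)b * 40503u;  // arbitrary bits
    for (int bit_pos = 0; bit_pos < 32; bit_pos++)
        if (extract_serial(planes, bit_pos) != extract_tree(planes, bit_pos))
            return false;
    return true;
}
```
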
### Step 4b: Branchless absmax decode (HIGH)

**What:** Remove the two conditional branches in `decode_e4m4_absmax`.

**Current code** (generates 16 BSSY/BSYNC pairs per TILE_K):
```cpp
if (raw == 0) return 0.0f;       // branch + convergence
int e = raw >> 4;
int m = raw & 0xF;
if (e == 0) return ldexpf(...);  // branch + convergence
```

**Fixed code** (branchless, uses bit manipulation):
```cpp
int e = raw >> 4;
int m = raw & 0xF;
// Normal path: construct IEEE 754 directly
unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23
                  | (unsigned int)m << 19;
float result = __uint_as_float(ieee);
// Predicated zero-out for raw == 0 (no branch)
result = (raw == 0) ? 0.0f : result;
```

Drop subnormal handling entirely (e==0 produces absmax < 2^-10, which
is effectively zero for quantized weights — no real weight block has
an absmax this small).

**Expected impact:** 5-10% improvement from eliminating 512 BSSY/BSYNC
convergence points per block.

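The rewrite can be validated host-side against the branchy reference. In this sketch, `E4M4_BIAS = 7` is an assumed value and `uint_as_float` stands in for the device intrinsic `__uint_as_float`; the two decoders must agree on every input except the e == 0 subnormals that are intentionally dropped:

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

constexpr int E4M4_BIAS = 7;  // ASSUMED bias, for illustration only

// Host stand-in for the device intrinsic __uint_as_float.
float uint_as_float(uint32_t u) {
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Reference: the branchy decode, subnormal path included.
float decode_branchy(int raw) {
    if (raw == 0) return 0.0f;
    int e = raw >> 4;
    int m = raw & 0xF;
    if (e == 0) return std::ldexp((float)m / 16.0f, 1 - E4M4_BIAS);
    return std::ldexp(1.0f + (float)m / 16.0f, e - E4M4_BIAS);
}

// Branchless version: normal path computed unconditionally, predicated
// select for raw == 0; e == 0 subnormals intentionally decode wrong.
float decode_branchless(int raw) {
    int e = raw >> 4;
    int m = raw & 0xF;
    uint32_t ieee = (uint32_t)(e - E4M4_BIAS + 127) << 23
                  | (uint32_t)m << 19;
    float result = uint_as_float(ieee);
    result = (raw == 0) ? 0.0f : result;
    return result;
}
```
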
### Step 4c: B fragment register double-buffering (HIGH)

**What:** Preload the next N_block's B planes from shmem while the
current dequant ALU work runs. Hides the 20-30 cycle shmem load latency.

**Expected impact:** 10-15% improvement on all shapes.
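
The intended schedule can be sketched as a host-side software-pipelined loop. This is a toy model: `load`, `process`, and `N_BLOCKS` are stand-ins for the shmem B-plane reads, the dequant + MMA work, and the N_block trip count.

```cpp
#include <cstdint>

constexpr int N_BLOCKS = 4;  // toy N_block trip count

uint32_t load(const uint32_t* shmem_b, int nb) { return shmem_b[nb]; }
uint32_t process(uint32_t planes) { return planes * 2u + 1u; }  // toy dequant

// Double-buffered inner loop: iteration nb's compute overlaps the
// preload of iteration nb+1's data, pulling the load off the critical path.
void inner_loop_double_buffered(const uint32_t* shmem_b, uint32_t* out) {
    uint32_t cur = load(shmem_b, 0);            // prologue: first load
    for (int nb = 0; nb < N_BLOCKS; nb++) {
        uint32_t nxt = 0;
        if (nb + 1 < N_BLOCKS)
            nxt = load(shmem_b, nb + 1);        // preload next N_block's planes
        out[nb] = process(cur);                 // overlaps the preload's latency
        cur = nxt;                              // register double-buffer swap
    }
}
```

On the device the same shape applies per k_sub_tile, with the preload issuing LDS instructions whose latency is covered by the current N_block's dequant ALU instructions.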

### Step 5: TILE_N=256 + TILE_K=128 for large shapes (HIGH)