@@ -90,17 +90,62 @@ With k_splits=2 and 16 k_tiles per split, the pipeline is even shorter.
A 3-stage pipeline (instead of 2) provides 2x more latency slack at
the cost of 1 more prefill iteration.

- ### 3.3 Dequant compute cost
+ ### 3.3 Dequant compute cost (SASS analysis, K=4, M_BLOCKS=2, fp16)

- Per weight element: ~13 ALU ops (bit extract + shuffle codebook + scale).
- cuBLAS does 0 ops per weight element (just feeds fp16 to MMA). This is
- inherent and cannot be eliminated — it is the price of compression.
+ The compiled kernel has **1264 SASS instructions**. The instruction mix:

- But the dequant runs on INT32/FP16 ALU while MMA runs on tensor cores.
- They are different functional units. With proper scheduling (B fragment
- double-buffering, deeper pipeline), the dequant can overlap with MMA
- and memory loads. Currently the dequant is on the critical path because
- the inner loop is sequential: load B → dequant → MMA → next N-block.
+ | Category | Count | % | What |
+ |----------|------:|---:|------|
+ | Bit manipulation (SHF+LOP3+IMAD) | 628 | 57% | Dequant + address math |
+ | Tensor core (HMMA) | 16 | 1.5% | The actual matmul |
+ | Codebook + scale (SHFL+HMUL2) | 64 | 5.8% | Shuffle lookup + absmax multiply |
+ | Type conversion (F2FP+F2I+I2F) | 40 | 3.6% | Absmax decode, half↔float |
+ | Control flow (BRA+BSSY+BSYNC+ISETP) | 187 | 17% | Branches, divergence, compares |
+ | Memory (LDS+LDSM+LDGSTS+PRMT) | 147 | 13% | Shmem, cp.async, permutes |
+
+ **The kernel is 39:1 ALU to tensor-core.** The tensor cores are idle 98.5%
+ of the time. Three specific problems:
+
+ **Problem 1: Bit extraction dependency chain.** The inner loop:
+ ```cpp
+ for (int b = 0; b < K_BITS; b++)
+     idx |= ((planes[b] >> bit_pos) & 1) << b;
+ ```
+ Each `idx |=` depends on the previous value of `idx`, creating a serial
+ chain of ~12 dependent operations for K=4. With only 2 warps per
+ scheduler (occupancy = 12.5%), pipeline stalls of 2 cycles per
+ dependent pair cannot be hidden. For 32 elements per TILE_K × 32
+ k_tiles: an estimated **~10us of dependency stalls**.
+
+ Fix: restructure to a tree reduction with independent extractions:
+ ```cpp
+ int b0 = (planes[0] >> bit_pos) & 1;  // 4 independent extractions
+ int b1 = (planes[1] >> bit_pos) & 1;
+ int b2 = (planes[2] >> bit_pos) & 1;
+ int b3 = (planes[3] >> bit_pos) & 1;
+ int idx = b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);  // tree combine
+ ```
+ This reduces the dependency chain from depth 12 to depth 4. With LOP3
+ (3-input boolean ops), the combine is 2 instructions.
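
As a quick host-side check (plain C++, K_BITS=4, arbitrary test planes chosen here for illustration), the tree form computes the same index as the serial loop for every bit position:

```cpp
#include <cstdint>

// Serial form from the kernel: each |= depends on the previous idx.
inline int extract_serial(const uint32_t planes[4], int bit_pos) {
    int idx = 0;
    for (int b = 0; b < 4; b++)          // K_BITS = 4
        idx |= ((planes[b] >> bit_pos) & 1) << b;
    return idx;
}

// Tree form: 4 independent extractions, then one shallow combine.
inline int extract_tree(const uint32_t planes[4], int bit_pos) {
    int b0 = (planes[0] >> bit_pos) & 1;
    int b1 = (planes[1] >> bit_pos) & 1;
    int b2 = (planes[2] >> bit_pos) & 1;
    int b3 = (planes[3] >> bit_pos) & 1;
    return b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);
}
```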
+
+ **Problem 2: Branchy absmax decode.** `decode_e4m4_absmax` has two
+ conditional branches (`if raw == 0`, `if e == 0`) that generate 16
+ BSSY/BSYNC divergence-handling pairs per TILE_K iteration. These
+ execute 512 times per block (16 × 32 k_tiles). Even when never taken,
+ each pair costs ~4-6 cycles of convergence overhead, for **~2-3us total**.
+
+ Fix: make it branchless. Compute the normal-path result unconditionally,
+ then use a predicated select for the edge cases (or just accept that
+ raw=0 and subnormal absmax are negligibly rare and let the normal
+ formula handle them, producing a harmless wrong value for impossible
+ inputs).
+
+ **Problem 3: Low occupancy.** 72 registers per thread × 256 threads =
+ 18,432 registers per block. The SM has 65,536 registers, so only 3
+ blocks fit, but shared memory limits it to 1 block (8 warps). With
+ 4 warp schedulers, each has only 2 warps to choose from. Every memory
+ or ALU latency that both warps hit simultaneously leaves the scheduler
+ idle. cuBLAS typically runs at 25-50% occupancy for comparable shapes.
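
The resource arithmetic above can be sanity-checked with a small host-side sketch (the 64K register file, 64-warp limit, and 4 schedulers are assumptions for an Ampere-class SM; register allocation granularity is ignored):

```cpp
// Register-limited resident blocks per SM: regfile / (regs_per_thread * threads).
inline int reg_limited_blocks(int regfile, int regs_per_thread, int threads) {
    return regfile / (regs_per_thread * threads);
}

// Warps each scheduler can pick from, given the actual resident block count.
inline int warps_per_scheduler(int resident_blocks, int threads, int schedulers) {
    return resident_blocks * (threads / 32) / schedulers;
}
```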

---

@@ -230,29 +275,79 @@ compute stalls, barrier stalls, or something else entirely.
This informs whether further optimization should focus on memory access
patterns, compute scheduling, or pipeline structure.

- ### Step 4: B fragment register double-buffering (HIGH)
+ ### Step 4: Fix bit extraction dependency chain (HIGH)

- **What:** Overlap shmem B loads with MMA execution in the inner loop.
+ **What:** Restructure the inner loop bit extraction to eliminate the
+ serial `idx |=` dependency chain.

- Current inner loop (per k_sub_tile, per N_block):
+ **Current code** (12-deep dependency chain for K=4):
+ ```cpp
+ int idx = 0;
+ for (int b = 0; b < K_BITS; b++)
+     idx |= ((planes[b] >> bit_pos) & 1) << b;
```
- load B planes from shmem → dequant → MMA → next N_block
-     [stall]                [ALU]     [TC]
+
+ **Fixed code** (4-deep, independent extractions + tree combine):
+ ```cpp
+ int b0 = (planes[0] >> bit_pos) & 1;
+ int b1 = (planes[1] >> bit_pos) & 1;
+ int b2 = (planes[2] >> bit_pos) & 1;
+ int b3 = (planes[3] >> bit_pos) & 1;
+ int idx = b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);
```

- With double-buffering:
+ The 4 extractions are independent (no data dependency), so the compiler
+ can schedule them across pipeline stages. The combine uses LOP3 (2
+ instructions for a 4-input OR with shifts). Dependency depth: 4 vs 12.
+
+ For K=2,3,5: the same pattern applies, with 2, 3, or 5 independent extractions.
+
+ **Also fix: process 4 elements with interleaved extractions.** Currently
+ the inner loop processes elements r=0..3 sequentially. Interleaving
+ the bit extraction across elements increases ILP further: while
+ element 0's extraction stalls on ALU latency, element 1's extraction
+ can issue.
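
A minimal host-side sketch of the interleaved form, assuming for illustration that element r reads bit position r (the kernel's real per-element indexing may differ):

```cpp
#include <cstdint>

// Sketch: extract indices for elements r = 0..3 in one pass. Each
// iteration is independent of the others, so once the compiler unrolls
// the loop, all 16 shift+and ops can issue back to back instead of
// waiting on a serial idx chain.
inline void extract4_interleaved(const uint32_t planes[4], int idx[4]) {
    for (int r = 0; r < 4; r++) {   // iterations carry no dependency
        int b0 = (planes[0] >> r) & 1;
        int b1 = (planes[1] >> r) & 1;
        int b2 = (planes[2] >> r) & 1;
        int b3 = (planes[3] >> r) & 1;
        idx[r] = b0 | (b1 << 1) | (b2 << 2) | (b3 << 3);
    }
}
```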
+
+ **Expected impact:** 15-25% improvement on all shapes by reducing
+ dependency stalls from ~10us to ~3-4us per 32 k_tiles.
+
+ ### Step 4b: Branchless absmax decode (HIGH)
+
+ **What:** Remove the two conditional branches in `decode_e4m4_absmax`.
+
+ **Current code** (generates 16 BSSY/BSYNC pairs per TILE_K):
+ ```cpp
+ if (raw == 0) return 0.0f;       // branch + convergence
+ int e = raw >> 4;
+ int m = raw & 0xF;
+ if (e == 0) return ldexpf(...);  // branch + convergence
```
- preload B[nb+1] from shmem → dequant B[nb] → MMA → next
-     [shmem load]            [ALU, overlap]   [TC]
+
+ **Fixed code** (branchless, uses bit manipulation):
+ ```cpp
+ int e = raw >> 4;
+ int m = raw & 0xF;
+ // Normal path: construct the IEEE 754 bit pattern directly
+ unsigned int ieee = (unsigned int)(e - E4M4_BIAS + 127) << 23
+                   | (unsigned int)m << 19;
+ float result = __uint_as_float(ieee);
+ // Predicated zero-out for raw == 0 (no branch)
+ result = (raw == 0) ? 0.0f : result;
```

- The shmem loads for the next N_block's B planes (4 uint32 reads,
- ~20-30 cycle latency each) overlap with the current N_block's dequant
- ALU work. This removes the shmem load stall from the critical path.
+ Drop subnormal handling entirely: e==0 produces absmax < 2^-10, which
+ is effectively zero for quantized weights; no real weight block has
+ an absmax this small.
+
+ **Expected impact:** 5-10% improvement from eliminating 512 BSSY/BSYNC
+ convergence points per block.
+
+ ### Step 4c: B fragment register double-buffering (HIGH)
+
+ **What:** Preload the next N_block's B planes from shmem while the current
+ dequant ALU work runs. Hides the 20-30 cycle shmem load latency.
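
A host-side sketch of the pattern (the names `NB`, `PLANES`, `load_planes`, and the summing stand-in for dequant+MMA are hypothetical; the real kernel would rotate B fragment registers the same way):

```cpp
#include <cstdint>

static const int NB = 4;        // hypothetical N_blocks per tile
static const int PLANES = 4;    // K_BITS bit planes per B fragment

// Stand-in for the shmem (LDS) read of one N_block's B planes.
static void load_planes(const uint32_t shmem[NB][PLANES], int nb,
                        uint32_t frag[PLANES]) {
    for (int p = 0; p < PLANES; p++) frag[p] = shmem[nb][p];
}

// Double-buffered inner loop: the load for nb+1 is issued before the
// (stand-in) dequant work on nb, so its latency overlaps the ALU work.
inline uint32_t sum_double_buffered(const uint32_t shmem[NB][PLANES]) {
    uint32_t frag[2][PLANES];
    uint32_t sum = 0;
    load_planes(shmem, 0, frag[0]);                 // prologue: preload nb = 0
    for (int nb = 0; nb < NB; nb++) {
        if (nb + 1 < NB)
            load_planes(shmem, nb + 1, frag[(nb + 1) & 1]);  // prefetch next
        for (int p = 0; p < PLANES; p++)            // "dequant + MMA" stand-in
            sum += frag[nb & 1][p];
    }
    return sum;
}
```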

- **Expected impact:** 10-20% improvement on all shapes. The dequant ALU
- work (~50 cycles per N_block iteration) provides enough instructions to
- hide the shmem load latency.
+ **Expected impact:** 10-15% improvement on all shapes.

### Step 5: TILE_N=256 + TILE_K=128 for large shapes (HIGH)
