@@ -122,20 +122,24 @@ than fp16 is at ~16 concurrent users.
 
 ---
 
-## Five-kernel strategy
+## Four-kernel strategy
 
 Each kernel covers a range of M where it has a structural advantage.
-The dispatch logic selects the best kernel per (layer_type, M) pair.
+The dispatch logic (`kbit_linear`, `kbit_expert_linear`) selects the
+best kernel per (layer_type, M) pair. All kernels read tiled format
+(from `repack_kbit`) with E4M4 absmax.
 
-| Kernel | M range | Layer types | Data format |
-|--------|---------|-------------|-------------|
-| 1. Scalar GEMV | 1-4 | Dense, attention | Flat (quantize_kbit), float32 absmax |
-| 2. MMA dequant | 5-16 | Dense, attention | Tiled (repack_kbit), E4M4 absmax |
-| 3. Dequant + cuBLAS | 17+ | Dense, attention | Flat -> fp16 |
-| 4. Grouped scalar GEMV | 1-4 | MoE experts | Flat (quantize_kbit), float32 absmax |
-| 5. Grouped MMA | 1+ | MoE experts | Tiled (repack_kbit), E4M4 absmax |
+| Kernel | M range | Layer types | Dispatch function |
+|--------|---------|-------------|-------------------|
+| 1. Scalar GEMV | 1-4 | Dense, attention | `kbit_linear` |
+| 2. MMA dequant | 5-16 | Dense, attention | `kbit_linear` |
+| 3. Dequant + cuBLAS | 17+ | Dense, attention | `kbit_linear` |
+| 4. Grouped MMA | 1-16 | MoE experts | `kbit_expert_linear` |
 
-Why five kernels instead of one:
+For MoE at max_M > 16, `kbit_expert_linear` falls back to per-expert
+dequant + cuBLAS matmul (no dedicated kernel needed).
+
+Why four kernels instead of one:
 - At M=1, tensor cores waste 94% of their compute (m16n8k16 pads 15
   zero rows). A scalar kernel that avoids MMA entirely wins by 3-5x.
 - At M=5-16, MMA utilization rises to 31-100%. The 3.2x data
@@ -147,10 +151,6 @@ Why five kernels instead of one:
   its compute pipeline.
 - MoE experts launched individually waste 88-97% of SMs. Grouping
   all active experts into one kernel launch solves this.
-- The grouped scalar GEMV and grouped MMA serve complementary roles:
-  scalar wins at M=1-4 for moe_gu (K=2048, N=512) where its C=1
-  grid gives better parallelism; grouped MMA wins at all M for
-  moe_dn (K=512, N=2048) and at M>4 for moe_gu.
 
 **Practical importance (from workload analysis in `token_analysis.md`):**
 
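The (layer_type, M) routing in the table above can be sketched as follows. This is an illustrative sketch only: the enum, function name, and string labels are hypothetical, while the real logic lives in the `kbit_linear` / `kbit_expert_linear` dispatch in `ops.cu`.

```cpp
#include <string>

// Hypothetical sketch of the per-(layer_type, M) kernel selection.
// M ranges mirror the dispatch table in this document.
enum class LayerType { Dense, Attention, MoE };

std::string select_kernel(LayerType layer, int M) {
    if (layer == LayerType::MoE) {
        // Grouped MMA covers M = 1-16; larger M falls back to
        // per-expert dequant + cuBLAS.
        return M <= 16 ? "grouped_mma" : "per_expert_dequant_cublas";
    }
    if (M <= 4)  return "scalar_gemv";   // tensor cores mostly idle here
    if (M <= 16) return "mma_dequant";   // MMA utilization 31-100%
    return "dequant_cublas";             // cuBLAS scales best at large M
}
```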
@@ -194,17 +194,18 @@ range falls in the gap between these modes.
 - No shared memory for B data, no cp.async, no split-K
 
 **Data format:**
-- B_packed: flat from `quantize_kbit` — `[N * num_k_blocks * k]` uint32
-- B_absmax: flat float32 — `[N * num_k_blocks]`
-- No repack step needed
+- B_packed: tiled from `repack_kbit` — tiles of `[TILE_N × KB_PER_TILE × k]` uint32
+- B_absmax: tiled uint8 E4M4 — tiles of `[TILE_N × KB_PER_TILE]`
+- Supports both flat and tiled layouts via `TILED` template bool
+- Flat layout preserved for standalone use; tiled layout used by dispatch
 
 **Inner loop (V8):**
 
 Each thread strides through quantization blocks along K:
 ```
 for each quant block (stride 64):
   load k bit-plane words (vectorized: int2 for k=2, int4 for k=4)
-  load float32 absmax
+  load absmax (E4M4 → float decode)
 
   for sub = 0..3:  // 4 groups of 8 elements
     load A[m, k_base + sub*8 .. +7] via int4 (8 fp16 values)
@@ -329,12 +330,13 @@ the MMA dequant kernel takes ~68 us (instruction-limited, only 1.3%
 of execution is MMA). A fused dequant kernel would take ~5 us for
 this shape, so dequant + cuBLAS ~27 us would beat 68 us.
 
-**Dequant kernel** (`kDequantizeBlockwise_kbit_vec`): a single CUDA
-kernel that reads k-bit packed data + absmax and writes fp16 output.
-Templated on absmax type: float32 (from `quantize_kbit` directly),
-uint8 E4M4, or fp16. The float32 absmax path was added to eliminate
-a previous Python-side E4M4 conversion that launched ~15 PyTorch
-elementwise kernels (~800 us). Now it is a single kernel launch.
+**Dequant kernel:** Two variants:
+- `kDequantizeBlockwise_kbit_vec`: reads flat layout (from `quantize_kbit`)
+- `kDequantizeBlockwise_kbit_tiled`: reads tiled layout (from `repack_kbit`)
+
+Both are templated on absmax type (uint8 E4M4, fp16, float32) and
+output type (fp16, bf16, float32). The tiled variant is used by
+`kbit_linear` dispatch for the M>16 dequant+cuBLAS path.
 
 Dequant GPU kernel times (ncu-measured, k=4):
 
@@ -352,51 +354,15 @@ to the matmul. At M>=64, dequant+cuBLAS wins because cuBLAS scales
 efficiently while MMA is instruction-limited. The crossover is
 M=32-64 depending on shape.
 
-**Data format:** Uses flat layout (same as scalar GEMV). The
-`dequantize_kbit` launcher handles float32, uint8 E4M4, and fp16
-absmax via the `_KBIT_ABSMAX_SUFFIX` dispatch map.
-
----
-
-## 4. Grouped scalar GEMV (`kbit_grouped_scalar_gemv`)
-
-**Location:** `ops.cu` (search for `kbit_grouped_scalar_gemv`)
-
-**Operation:** For each expert e: C_e[M_e, N] = A_e[M_e, K] * W_e^T,
-all experts in one kernel launch.
-
-**Architecture (V8):**
-- 64 threads (2 warps), one output column per block (C=1)
-- Grid = (N, num_experts) — Y-dimension indexes experts
-- `__launch_bounds__(64, 24)` for M<=2, `__launch_bounds__(64, 16)` for M>2
-- M_VAL dispatch (1/2/3/4 templates)
-
-**Data format:**
-- B_packed_all: flat from `quantize_kbit` — concatenated per-expert,
-  each `[N * num_k_blocks * k]` uint32 (truncated to exact size)
-- B_absmax_all: flat float32 — concatenated per-expert,
-  each `[N * num_k_blocks]` float32 (truncated to exact size)
-- No repack step needed. Uses same flat layout as the dense scalar GEMV.
-
-**Inner loop:** Identical to the dense scalar GEMV (V8): vectorized
-int4 A loads, 4-group sub-loop of 8 elements, shuffle codebook lookup.
-The only difference is per-expert pointer arithmetic using
-`expert_offsets[expert_id]` to find each expert's A, B, and C regions.
-
-**Why grouped scalar wins for moe_gu (K=2048, N=512) at M<=4:**
-With C=1, the grid is N × num_experts = 512 × 8 = 4096 blocks. This
-gives full SM utilization (32 blocks/SM). The grouped MMA at this shape
-has far fewer blocks due to tiling overhead.
-
-**Quantize_kbit padding:** `quantize_kbit` appends a small padding
-(4 packed words + 1 absmax) to each expert's output. The test and
-benchmark helpers truncate each expert's data to the exact expected
-size before concatenation, so the kernel's arithmetic indexing
-(`expert_id * N * num_k_blocks * K_BITS`) works correctly.
+**Data format:** The flat variant (`dequantize_kbit`) reads flat layout
+from `quantize_kbit`. The tiled variant (`dequantize_kbit_tiled`) reads
+tiled layout from `repack_kbit`, used by `kbit_linear` dispatch.
+Both handle float32, uint8 E4M4, and fp16 absmax via the
+`_KBIT_ABSMAX_SUFFIX` dispatch map.
 
 ---
 
-## 5. Grouped MMA (`kbit_grouped_gemm_prod`)
+## 4. Grouped MMA (`kbit_grouped_gemm_prod`)
 
 **Location:** `ops.cu` (search for `kbit_grouped_gemm_prod`)
 
@@ -486,33 +452,31 @@ targets < 10 us.
 
 ## Data formats
 
-Two formats exist, and which kernel uses which matters:
+All inference kernels read **tiled format** (from `repack_kbit`).
+Flat format exists only as the intermediate output of `quantize_kbit`
+before repacking.
 
-**Flat (from `quantize_kbit`):**
+**Flat (from `quantize_kbit`) — intermediate only:**
 - B_packed: `[N * num_k_blocks * k]` uint32, row-major per column
-- B_absmax: `[N * num_k_blocks]` float32
-- No preprocessing. Used by: scalar GEMV, grouped scalar GEMV,
-  dequant kernel.
+- B_absmax: `[N * num_k_blocks]` float32 or uint8 E4M4
+- Used only during quantization. Converted to tiled by `repack_kbit`
+  at model load time, then discarded.
 
-**Tiled (from `repack_kbit`):**
+**Tiled (from `repack_kbit`) — runtime format:**
 - B_packed: reorganized into `[k_tiles * n_tiles * TILE_N * B_COL_WORDS]`
   for coalesced cp.async loads per tile
 - B_absmax: E4M4-encoded uint8, same tiled layout
-- Requires a one-time repack pass. Used by: MMA kernel, grouped MMA
-  kernel.
+- Used by: scalar GEMV, MMA kernel, grouped MMA kernel,
+  tiled dequant kernel (for the dequant+cuBLAS path)
 
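The tiled `[k_tiles * n_tiles * TILE_N * B_COL_WORDS]` layout can be illustrated with a small index-math sketch. The constant values here are assumptions for illustration only; the real `TILE_N` and `B_COL_WORDS` are defined in `ops.cu`.

```cpp
#include <cstddef>

// Hypothetical constants — the actual values live in ops.cu.
constexpr int TILE_N = 16;       // assumed output columns per tile
constexpr int B_COL_WORDS = 32;  // assumed packed words per column per k-tile

// Word offset of packed word `word` for column `n_in_tile` of tile
// (k_tile, n_tile). Tiles are laid out k-major; one tile's words are
// contiguous, which is what makes per-tile cp.async loads coalesced.
std::size_t tiled_word_offset(int k_tile, int n_tile, int n_tiles,
                              int n_in_tile, int word) {
    std::size_t tile = static_cast<std::size_t>(k_tile) * n_tiles + n_tile;
    return (tile * TILE_N + n_in_tile) * B_COL_WORDS + word;
}
```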
 E4M4 encodes each float32 absmax as a single byte (4-bit exponent +
 4-bit mantissa). Decode is branchless: `ldexp(mantissa, exponent-bias)`.
 This saves 4x bandwidth for absmax reads but adds a decode step in
 the inner loop.
 
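The branchless decode described above can be sketched in a few lines. The nibble order (exponent in the high nibble) and the bias value are assumptions for illustration; the actual encoding is fixed by the repack/quantize code.

```cpp
#include <cmath>
#include <cstdint>

constexpr int E4M4_BIAS = 11;  // assumed exponent bias — not the confirmed value

// Sketch of E4M4 decode: 4-bit exponent, 4-bit mantissa, no branches.
float decode_e4m4(std::uint8_t b) {
    int exponent = b >> 4;   // high nibble (assumed)
    int mantissa = b & 0xF;  // low nibble (assumed)
    // ldexp(m, e) computes m * 2^e in one branchless step
    return std::ldexp(static_cast<float>(mantissa), exponent - E4M4_BIAS);
}
```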
-**Note:** The grouped scalar GEMV and grouped MMA use different data
-formats. The grouped scalar GEMV uses flat layout with float32 absmax
-(same as the dense scalar GEMV), while the grouped MMA uses tiled
-layout with E4M4 absmax (same as the dense MMA). This means MoE
-expert weights must be stored in both formats if both kernels are used
-in the dispatch, or a runtime conversion must happen. Currently the
-benchmark prepares each format separately.
+The flat dequant kernel (`kDequantizeBlockwise_kbit_vec`) is still
+available for standalone use (e.g., debugging), but `kbit_linear`
+dispatch uses the tiled dequant (`kDequantizeBlockwise_kbit_tiled`).
 
 ---
 
@@ -539,7 +503,7 @@ per element; for k=2, ~8 ops.
 
 | GPU | SM | MMA instruction | Async MMA? | Kernel strategy |
 |-----|-----|-----------------|------------|----------------|
-| RTX 4090 | sm_89 | mma.sync | No | All 5 kernels as described |
+| RTX 4090 | sm_89 | mma.sync | No | All 4 kernels as described |
 | RTX 5090 | sm_120 | mma.sync (ext) | No | Same strategy, more SMs (192) |
 | H100/H200 | sm_90a | wgmma.mma_async | Yes | Could overlap dequant + MMA |
 | B200/GB200 | sm_100a | tcgen05.mma | Yes | Could overlap dequant + MMA |