Commit e78a28c

TimDettmers and claude committed

docs: Update kernel spec for format unification and dispatch

Remove grouped scalar GEMV references (kernel removed in ac7d6ff).
Update five-kernel → four-kernel strategy. Document tiled-only runtime
format and kbit_linear/kbit_expert_linear dispatch functions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent abd7c7f commit e78a28c

4 files changed: +67 −139 lines


benchmarking-report.md

Lines changed: 9 additions & 8 deletions
@@ -5,16 +5,17 @@ All kernel times are NCU `gpu__time_duration.avg` unless stated otherwise.

 ## Kernel dispatch

-Five kernels cover the full inference workload. Dispatch selects the fastest
-kernel per (layer_type, M) pair:
+Four kernels cover the full inference workload. `kbit_linear` and
+`kbit_expert_linear` dispatch to the fastest kernel per (layer_type, M):

 | Kernel | M range | Layers | Status |
 |--------|---------|--------|--------|
 | Scalar GEMV | 1-4 | Dense + attention | Done (V8), 1.5-1.9x faster than fp16 at M=1 |
 | MMA dequant | 5-16 | Dense + attention | Done, ~1.0-1.3x vs fp16 |
 | Dequant + cuBLAS | 17+ | Dense + attention | Done, ~0.95-1.0x vs fp16 |
-| Grouped scalar GEMV | 1-4 | MoE experts | Done, competitive with fp16 |
-| Grouped MMA | 5+ | MoE experts | Done, competitive with fp16 |
+| Grouped MMA | 1-16 | MoE experts | Done, competitive with fp16 |
+
+All kernels read tiled format (from `repack_kbit`) with E4M4 absmax.

 ## Per-shape speedups at M=1 (decode, dominant workload)

@@ -166,10 +167,10 @@ kernel launches (dequant + matmul), doubling the dispatch tax.
    end-to-end throughput by up to 1.5x on top of the current kernel
    speedups.

-4. **MoE grouped kernels need V8 optimizations.** The grouped scalar GEMV
-   currently matches fp16 but does not beat it. Porting the V8 inner loop
-   (vectorized A loads, 2-warp config, M-dispatch) would bring it closer
-   to the 1.5-1.9x speedups seen on dense layers.
+4. **MoE dispatch is unified.** The grouped MMA handles M<=16 for MoE
+   layers; for larger M, `kbit_expert_linear` falls back to per-expert
+   dequant + cuBLAS matmul. The grouped scalar GEMV was removed (it only
+   won one shape at M=1 by 0.3 us).

 5. **Lower k is strictly better for inference speed, not just model size.**
    k=2 is fastest at every M value because it reads the least data. The
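The dispatch rules in the table above can be sketched in Python. This is a hypothetical mirror for illustration only; the real `kbit_linear`/`kbit_expert_linear` dispatch lives in the CUDA backend, and the string kernel names here are assumptions:

```python
def select_kernel(layer_type: str, m: int) -> str:
    """Pick the fastest kernel for a (layer_type, M) pair."""
    if layer_type == "moe":
        # kbit_expert_linear: grouped MMA covers M=1-16; larger M
        # falls back to per-expert dequant + cuBLAS matmul.
        return "grouped_mma" if m <= 16 else "dequant_cublas_per_expert"
    # kbit_linear handles dense and attention layers.
    if m <= 4:
        return "scalar_gemv"      # decode (M=1-4)
    if m <= 16:
        return "mma_dequant"      # small-batch prefill
    return "dequant_cublas"       # large prefill
```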

deployment-summary.md

Lines changed: 6 additions & 39 deletions
@@ -14,18 +14,19 @@ is 80-84% of total GEMM wall-clock time in typical sessions.
 At 16+ concurrent users, the advantage disappears because large prefill
 chunks dominate and the dequant overhead exceeds the bandwidth savings.

-The system uses **5 CUDA kernels** dispatched per (layer_type, M):
+The system uses **4 CUDA kernels** dispatched by `kbit_linear` and
+`kbit_expert_linear` per (layer_type, M). All kernels read tiled
+format (from `repack_kbit`) with E4M4 absmax:

 | Kernel | M range | Layers | Mechanism |
 |--------|---------|--------|-----------|
 | Scalar GEMV | 1-4 | Dense + attn | 64 threads, shuffle codebook, no tensor cores |
 | MMA dequant | 5-16 | Dense + attn | Tensor core m16n8k16, inline dequant |
 | Dequant + cuBLAS | 17+ | Dense + attn | Separate dequant kernel → cuBLAS GEMM |
-| Grouped scalar GEMV | 1-4 | MoE experts | Same as scalar, batched across experts |
-| Grouped MMA | 1+ | MoE experts | Same as MMA, batched across experts |
+| Grouped MMA | 1-16 | MoE experts | Same as MMA, batched across experts |

-For MoE layers at large M (prefill), the grouped MMA kernel loses to
-fp16 BMM, so a hybrid dequant + cuBLAS BMM path is available.
+For MoE layers at max_M > 16 (prefill), `kbit_expert_linear` falls
+back to per-expert dequant + cuBLAS matmul.

 ---

@@ -191,40 +192,6 @@ fp16, the same model requires 140 GB (two H100s or four 4090s).

 ---

-## Grouped scalar GEMV: where it fits
-
-The grouped scalar GEMV (`kbit_grouped_scalar_gemv`) is a specialized
-kernel for MoE expert layers at M=1-4. It uses the same flat data format
-and shuffle codebook as the dense scalar GEMV.
-
-### When it wins
-
-Only for **moe_gu (K=2048, N=512) at M=1** — and barely:
-
-| Shape | M | Grouped scalar | Grp MMA | fp16 BMM | Winner |
-|-------|---|---------------|---------|----------|--------|
-| moe_gu | 1 | **11.3** | 11.6 | 11.7 | Grouped (by 0.3 us) |
-| moe_gu | 2 | 12.9 | **11.8** | 12.7 | Grp MMA |
-| moe_gu | 4 | 17.1 | **11.9** | 18.9 | Grp MMA |
-| moe_dn | 1 | 24.9 | **12.1** | 13.1 | Grp MMA |
-| moe_dn | 4 | 38.3 | **12.1** | 12.1 | Grp MMA |
-
-The grouped scalar is terrible on moe_dn (K=512): with only 512/64=8
-quant blocks per thread and C=1 (one column per block), the kernel is
-launch-overhead-dominated. The grouped MMA wins everywhere except that
-one moe_gu M=1 case.
-
-### Why it still exists
-
-1. It uses the flat data format (from `quantize_kbit` directly), no
-   repack step. If you only store weights in flat format, the grouped
-   scalar is the only MoE option at M=1-4.
-2. The moe_gu M=1 win is small but real in the most common workload
-   (single-user decode). Over thousands of layers, 0.3 us adds up.
-3. It provides a correctness cross-check against the grouped MMA.
-
----
-
 ## Remaining optimization opportunities

 ### 1. CUDA Graphs for hybrid path (medium impact, low effort)
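The per-expert fallback path described above (max_M > 16) can be sketched as follows. This is an illustrative NumPy stand-in under assumed shapes, not the real implementation, which launches the tiled dequant kernel plus a cuBLAS matmul per expert:

```python
import numpy as np

def expert_linear_fallback(A_per_expert, Wq_per_expert, dequant):
    """For each expert e: C_e[M_e, N] = A_e[M_e, K] @ W_e[N, K].T,
    dequantizing W_e first (the M > 16 prefill path)."""
    outputs = []
    for A_e, Wq_e in zip(A_per_expert, Wq_per_expert):
        W_e = dequant(Wq_e)          # stand-in for the tiled dequant kernel
        outputs.append(A_e @ W_e.T)  # stand-in for the cuBLAS matmul
    return outputs
```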

kbit-kernel-spec.md

Lines changed: 46 additions & 82 deletions
@@ -122,20 +122,24 @@ than fp16 is at ~16 concurrent users.

 ---

-## Five-kernel strategy
+## Four-kernel strategy

 Each kernel covers a range of M where it has a structural advantage.
-The dispatch logic selects the best kernel per (layer_type, M) pair.
+The dispatch logic (`kbit_linear`, `kbit_expert_linear`) selects the
+best kernel per (layer_type, M) pair. All kernels read tiled format
+(from `repack_kbit`) with E4M4 absmax.

-| Kernel | M range | Layer types | Data format |
-|--------|---------|-------------|-------------|
-| 1. Scalar GEMV | 1-4 | Dense, attention | Flat (quantize_kbit), float32 absmax |
-| 2. MMA dequant | 5-16 | Dense, attention | Tiled (repack_kbit), E4M4 absmax |
-| 3. Dequant + cuBLAS | 17+ | Dense, attention | Flat -> fp16 |
-| 4. Grouped scalar GEMV | 1-4 | MoE experts | Flat (quantize_kbit), float32 absmax |
-| 5. Grouped MMA | 1+ | MoE experts | Tiled (repack_kbit), E4M4 absmax |
+| Kernel | M range | Layer types | Dispatch function |
+|--------|---------|-------------|-------------------|
+| 1. Scalar GEMV | 1-4 | Dense, attention | `kbit_linear` |
+| 2. MMA dequant | 5-16 | Dense, attention | `kbit_linear` |
+| 3. Dequant + cuBLAS | 17+ | Dense, attention | `kbit_linear` |
+| 4. Grouped MMA | 1-16 | MoE experts | `kbit_expert_linear` |

-Why five kernels instead of one:
+For MoE at max_M > 16, `kbit_expert_linear` falls back to per-expert
+dequant + cuBLAS matmul (no dedicated kernel needed).
+
+Why four kernels instead of one:
 - At M=1, tensor cores waste 94% of their compute (m16n8k16 pads 15
   zero rows). A scalar kernel that avoids MMA entirely wins by 3-5x.
 - At M=5-16, MMA utilization rises to 31-100%. The 3.2x data
@@ -147,10 +151,6 @@ Why five kernels instead of one:
   its compute pipeline.
 - MoE experts launched individually waste 88-97% of SMs. Grouping
   all active experts into one kernel launch solves this.
-- The grouped scalar GEMV and grouped MMA serve complementary roles:
-  scalar wins at M=1-4 for moe_gu (K=2048, N=512) where its C=1
-  grid gives better parallelism; grouped MMA wins at all M for
-  moe_dn (K=512, N=2048) and at M>4 for moe_gu.

 **Practical importance (from workload analysis in `token_analysis.md`):**

@@ -194,17 +194,18 @@ range falls in the gap between these modes.
 - No shared memory for B data, no cp.async, no split-K

 **Data format:**
-- B_packed: flat from `quantize_kbit` — `[N * num_k_blocks * k]` uint32
-- B_absmax: flat float32 — `[N * num_k_blocks]`
-- No repack step needed
+- B_packed: tiled from `repack_kbit` — tiles of `[TILE_N × KB_PER_TILE × k]` uint32
+- B_absmax: tiled uint8 E4M4 — tiles of `[TILE_N × KB_PER_TILE]`
+- Supports both flat and tiled layouts via `TILED` template bool
+- Flat layout preserved for standalone use; tiled layout used by dispatch

 **Inner loop (V8):**

 Each thread strides through quantization blocks along K:
 ```
 for each quant block (stride 64):
   load k bit-plane words (vectorized: int2 for k=2, int4 for k=4)
-  load float32 absmax
+  load absmax (E4M4 → float decode)

   for sub = 0..3: // 4 groups of 8 elements
     load A[m, k_base + sub*8 .. +7] via int4 (8 fp16 values)
@@ -329,12 +330,13 @@ the MMA dequant kernel takes ~68 us (instruction-limited, only 1.3%
 of execution is MMA). A fused dequant kernel would take ~5 us for
 this shape, so dequant + cuBLAS ~27 us would beat 68 us.

-**Dequant kernel** (`kDequantizeBlockwise_kbit_vec`): a single CUDA
-kernel that reads k-bit packed data + absmax and writes fp16 output.
-Templated on absmax type: float32 (from `quantize_kbit` directly),
-uint8 E4M4, or fp16. The float32 absmax path was added to eliminate
-a previous Python-side E4M4 conversion that launched ~15 PyTorch
-elementwise kernels (~800 us). Now it is a single kernel launch.
+**Dequant kernel:** Two variants:
+- `kDequantizeBlockwise_kbit_vec`: reads flat layout (from `quantize_kbit`)
+- `kDequantizeBlockwise_kbit_tiled`: reads tiled layout (from `repack_kbit`)
+
+Both are templated on absmax type (uint8 E4M4, fp16, float32) and
+output type (fp16, bf16, float32). The tiled variant is used by
+`kbit_linear` dispatch for the M>16 dequant+cuBLAS path.

 Dequant GPU kernel times (ncu-measured, k=4):

@@ -352,51 +354,15 @@ to the matmul. At M>=64, dequant+cuBLAS wins because cuBLAS scales
 efficiently while MMA is instruction-limited. The crossover is
 M=32-64 depending on shape.

-**Data format:** Uses flat layout (same as scalar GEMV). The
-`dequantize_kbit` launcher handles float32, uint8 E4M4, and fp16
-absmax via the `_KBIT_ABSMAX_SUFFIX` dispatch map.
-
----
-
-## 4. Grouped scalar GEMV (`kbit_grouped_scalar_gemv`)
-
-**Location:** `ops.cu` (search for `kbit_grouped_scalar_gemv`)
-
-**Operation:** For each expert e: C_e[M_e, N] = A_e[M_e, K] * W_e^T,
-all experts in one kernel launch.
-
-**Architecture (V8):**
-- 64 threads (2 warps), one output column per block (C=1)
-- Grid = (N, num_experts) — Y-dimension indexes experts
-- `__launch_bounds__(64, 24)` for M<=2, `__launch_bounds__(64, 16)` for M>2
-- M_VAL dispatch (1/2/3/4 templates)
-
-**Data format:**
-- B_packed_all: flat from `quantize_kbit` — concatenated per-expert,
-  each `[N * num_k_blocks * k]` uint32 (truncated to exact size)
-- B_absmax_all: flat float32 — concatenated per-expert,
-  each `[N * num_k_blocks]` float32 (truncated to exact size)
-- No repack step needed. Uses same flat layout as the dense scalar GEMV.
-
-**Inner loop:** Identical to the dense scalar GEMV (V8): vectorized
-int4 A loads, 4-group sub-loop of 8 elements, shuffle codebook lookup.
-The only difference is per-expert pointer arithmetic using
-`expert_offsets[expert_id]` to find each expert's A, B, and C regions.
-
-**Why grouped scalar wins for moe_gu (K=2048, N=512) at M<=4:**
-With C=1, the grid is N × num_experts = 512 × 8 = 4096 blocks. This
-gives full SM utilization (32 blocks/SM). The grouped MMA at this shape
-has far fewer blocks due to tiling overhead.
-
-**Quantize_kbit padding:** `quantize_kbit` appends a small padding
-(4 packed words + 1 absmax) to each expert's output. The test and
-benchmark helpers truncate each expert's data to the exact expected
-size before concatenation, so the kernel's arithmetic indexing
-(`expert_id * N * num_k_blocks * K_BITS`) works correctly.
+**Data format:** The flat variant (`dequantize_kbit`) reads flat layout
+from `quantize_kbit`. The tiled variant (`dequantize_kbit_tiled`) reads
+tiled layout from `repack_kbit`, used by `kbit_linear` dispatch.
+Both handle float32, uint8 E4M4, and fp16 absmax via the
+`_KBIT_ABSMAX_SUFFIX` dispatch map.

 ---

-## 5. Grouped MMA (`kbit_grouped_gemm_prod`)
+## 4. Grouped MMA (`kbit_grouped_gemm_prod`)

 **Location:** `ops.cu` (search for `kbit_grouped_gemm_prod`)

@@ -486,33 +452,31 @@ targets < 10 us.

 ## Data formats

-Two formats exist, and which kernel uses which matters:
+All inference kernels read **tiled format** (from `repack_kbit`).
+Flat format exists only as the intermediate output of `quantize_kbit`
+before repacking.

-**Flat (from `quantize_kbit`):**
+**Flat (from `quantize_kbit`) — intermediate only:**
 - B_packed: `[N * num_k_blocks * k]` uint32, row-major per column
-- B_absmax: `[N * num_k_blocks]` float32
-- No preprocessing. Used by: scalar GEMV, grouped scalar GEMV,
-  dequant kernel.
+- B_absmax: `[N * num_k_blocks]` float32 or uint8 E4M4
+- Used only during quantization. Converted to tiled by `repack_kbit`
+  at model load time, then discarded.

-**Tiled (from `repack_kbit`):**
+**Tiled (from `repack_kbit`) — runtime format:**
 - B_packed: reorganized into `[k_tiles * n_tiles * TILE_N * B_COL_WORDS]`
   for coalesced cp.async loads per tile
 - B_absmax: E4M4-encoded uint8, same tiled layout
-- Requires a one-time repack pass. Used by: MMA kernel, grouped MMA
-  kernel.
+- Used by: scalar GEMV, MMA kernel, grouped MMA kernel,
+  tiled dequant kernel (for dequant+cuBLAS path).

 E4M4 encodes each float32 absmax as a single byte (4-bit exponent +
 4-bit mantissa). Decode is branchless: `ldexp(mantissa, exponent-bias)`.
 This saves 4x bandwidth for absmax reads but adds a decode step in
 the inner loop.

-**Note:** The grouped scalar GEMV and grouped MMA use different data
-formats. The grouped scalar GEMV uses flat layout with float32 absmax
-(same as the dense scalar GEMV), while the grouped MMA uses tiled
-layout with E4M4 absmax (same as the dense MMA). This means MoE
-expert weights must be stored in both formats if both kernels are used
-in the dispatch, or a runtime conversion must happen. Currently the
-benchmark prepares each format separately.
+The flat dequant kernel (`kDequantizeBlockwise_kbit_vec`) is still
+available for standalone use (e.g., debugging), but `kbit_linear`
+dispatch uses the tiled dequant (`kDequantizeBlockwise_kbit_tiled`).

 ---

@@ -539,7 +503,7 @@ per element; for k=2, ~8 ops.

 | GPU | SM | MMA instruction | Async MMA? | Kernel strategy |
 |-----|-----|-----------------|------------|----------------|
-| RTX 4090 | sm_89 | mma.sync | No | All 5 kernels as described |
+| RTX 4090 | sm_89 | mma.sync | No | All 4 kernels as described |
 | RTX 5090 | sm_120 | mma.sync (ext) | No | Same strategy, more SMs (192) |
 | H100/H200 | sm_90a | wgmma.mma_async | Yes | Could overlap dequant + MMA |
 | B200/GB200 | sm_100a | tcgen05.mma | Yes | Could overlap dequant + MMA |
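The E4M4 absmax scheme described in the spec (4-bit exponent, 4-bit mantissa, `ldexp` decode) can be illustrated with a small Python sketch. The bias value and rounding details here are assumptions for illustration; the actual kernel constants may differ:

```python
import math

E4M4_BIAS = 11  # assumed exponent bias; the kernel's constant may differ

def e4m4_encode(x: float) -> int:
    """Pack a positive absmax into one byte: exponent nibble | mantissa nibble."""
    m, e = math.frexp(x)        # x = m * 2**e with 0.5 <= m < 1
    mant = round(m * 16)        # 4-bit mantissa with leading bit set (8..15)
    exp = e - 4 + E4M4_BIAS
    if mant == 16:              # rounding carried into a new bit: renormalize
        mant, exp = 8, exp + 1
    exp = min(max(exp, 0), 15)  # clamp to the 4-bit exponent range
    return (exp << 4) | mant

def e4m4_decode(b: int) -> float:
    """Branchless decode: ldexp(mantissa, exponent - bias)."""
    return math.ldexp(b & 0xF, (b >> 4) - E4M4_BIAS)
```

With the leading mantissa bit always set after normalization, the round-trip relative error stays below 1/16, which is why a single byte is enough for per-block scales.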

summary.md

Lines changed: 6 additions & 10 deletions
@@ -6,8 +6,8 @@ Base: `23f92e5` (feature/kbit-gemv-v8)
 ## What changed

 All kbit kernels now use uint8 E4M4 absmax by default, replacing float32.
-A float16 absmax alternative path is available for scalar GEMV and grouped
-scalar GEMV if higher absmax precision is needed.
+A float16 absmax alternative path is available for scalar GEMV if higher
+absmax precision is needed.

 ### Kernel changes

@@ -17,19 +17,18 @@ scalar GEMV if higher absmax precision is needed.
   re-encoding from float32.
 - **Scalar GEMV** (dense): `unsigned char*` absmax with `load_absmax<T>`
   decode. Templated on `ABSMAX_T` for uint8 (default) and float16.
-- **Grouped scalar GEMV** (MoE): Same treatment as dense scalar GEMV.
 - **MMA kernels** (dense + grouped): Already used uint8 E4M4 — no change.
 - **Dequantize**: Already supported uint8 — no change.

 ### Files modified (8)

 - `csrc/ops.cu` — E4M4 encode/decode moved before quantize kernel,
   quantize writes uint8, repack accepts uint8, fp16abs template
-  instantiations for scalar/grouped GEMV
+  instantiations for scalar GEMV
 - `csrc/pythonInterface.cpp` — All wrappers updated for `unsigned char*`;
-  added 16 extern C symbols for fp16abs scalar/grouped GEMV
+  added extern C symbols for fp16abs scalar GEMV
 - `bitsandbytes/backends/cuda/ops.py` — uint8 allocation in quantize,
-  absmax dtype routing in scalar/grouped GEMV dispatch
+  absmax dtype routing in scalar GEMV dispatch
 - `bitsandbytes/_ops.py` — quantize_kbit fake op returns uint8
 - `bitsandbytes/functional.py` — Removed redundant Python-side E4M4 encode
 - `tests/test_scalar_gemv.py` — E4M4 decode in reference functions
@@ -62,10 +61,7 @@ RTX 4090, CUDA events timing, fp16.
   consistent direction. Run-to-run variance dominates. No measurable
   regression.

-**Grouped scalar GEMV / MoE** (8 configs, 8 experts):
-- M≥2: within noise (±3%)
-- M=1: possible ~5% overhead from E4M4 decode cost being a larger fraction
-  of the small per-warp workload. One outlier at +22% is likely noise.
+**MoE grouped MMA** (8 configs, 8 experts): No change (already uint8 E4M4).

 **MMA kernels**: No change (already uint8 E4M4).

