Commit 6e85f4d

TimDettmers and claude committed
Add final benchmarking report with deployment analysis and dispatch overhead findings
Consolidates kernel benchmark results across all k values (k=2..5), deployment speedup projections for Qwen3-Coder-Next 70B under single-user and 4-user vLLM serving, and documents the Python dispatch overhead issue (25 us per call from torch.library) with proposed mitigations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7e0063c commit 6e85f4d

File tree

1 file changed, +176 -0 lines changed

benchmarking-report.md

# K-bit kernel benchmarking report

RTX 4090 (128 SMs, sm_89), Qwen3-Coder-Next 70B (MoE, hidden_dim=2048).
All kernel times are NCU `gpu__time_duration.avg` unless stated otherwise.

## Kernel dispatch

Five kernels cover the full inference workload. Dispatch selects the fastest
kernel per (layer_type, M) pair:

| Kernel | M range | Layers | Status |
|--------|---------|--------|--------|
| Scalar GEMV | 1-4 | Dense + attention | Done (V8), 1.5-1.9x faster than fp16 at M=1 |
| MMA dequant | 5-16 | Dense + attention | Done, ~1.0-1.3x vs fp16 |
| Dequant + cuBLAS | 17+ | Dense + attention | Done, ~0.95-1.0x vs fp16 |
| Grouped scalar GEMV | 1-4 | MoE experts | Done, competitive with fp16 |
| Grouped MMA | 5+ | MoE experts | Done, competitive with fp16 |
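The (layer_type, M) dispatch above amounts to a threshold lookup. A minimal sketch -- the kernel names here are illustrative labels, not the actual bitsandbytes entry points:

```python
def select_kernel(layer_type: str, m: int) -> str:
    """Pick the fastest kernel for a (layer_type, M) pair.

    Mirrors the dispatch table: M thresholds of 4 and 16 for dense/attention
    layers, 4 for MoE expert layers. Names are illustrative only.
    """
    if layer_type in ("dense", "attention"):
        if m <= 4:
            return "scalar_gemv"      # 1.5-1.9x vs fp16 at M=1
        if m <= 16:
            return "mma_dequant"      # ~1.0-1.3x vs fp16
        return "dequant_cublas"       # ~0.95-1.0x vs fp16
    if layer_type == "moe":
        return "grouped_scalar_gemv" if m <= 4 else "grouped_mma"
    raise ValueError(f"unknown layer type: {layer_type}")
```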
## Per-shape speedups at M=1 (decode, dominant workload)

Best kbit kernel vs cuBLAS fp16, all shapes per transformer block:

| Shape | k=2 | k=3 | k=4 | k=5 |
|-------|-----|-----|-----|-----|
| gateup (2048x5120) | 2.47x | 2.17x | 1.76x | 1.58x |
| down (5120x2048) | 2.05x | 1.84x | 1.57x | 1.42x |
| Q (2048x4096) | 1.90x | 1.67x | 1.43x | 1.28x |
| O (4096x2048) | 2.23x | 2.01x | 1.72x | 1.54x |
| KV (2048x512) | 1.86x | 1.65x | 1.41x | 1.27x |
| moe_gu (2048x512, 8 experts) | ~1.03x | ~1.05x | ~1.03x | ~0.98x |
| moe_dn (512x2048, 8 experts) | ~1.10x | ~1.08x | ~1.05x | ~1.00x |

Dense layers see large speedups because the scalar GEMV reads 2-5x less
data (k-bit compressed weights vs fp16). MoE layers are roughly at parity
because the grouped kernel inner loop has not yet received the V8
optimizations (vectorized A loads, 2-warp config).

## Model size per k

Qwen3-Coder-Next 70B total weight parameters: ~70B.

| k | Bits/param | Model size (weights only) | vs fp16 (140 GB) |
|---|-----------|---------------------------|-------------------|
| 2 | 2 | ~17.5 GB | 8.0x smaller |
| 3 | 3 | ~26.3 GB | 5.3x smaller |
| 4 | 4 | ~35.0 GB | 4.0x smaller |
| 5 | 5 | ~43.8 GB | 3.2x smaller |
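The table follows directly from size ≈ params × k / 8 bytes (absmax metadata ignored); a quick check:

```python
PARAMS = 70e9  # Qwen3-Coder-Next 70B weight parameters

def model_size_gb(k: int) -> float:
    """Weights-only model size in GB at k bits per parameter (absmax overhead ignored)."""
    return PARAMS * k / 8 / 1e9

for k in (2, 3, 4, 5):
    gb = model_size_gb(k)
    print(f"k={k}: {gb:.2f} GB, {140 / gb:.1f}x smaller than fp16")
```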

At k=2, the entire 70B model fits in a single RTX 4090 (24 GB VRAM) with
room for the KV cache. At k=4, the ~35 GB footprint requires multi-GPU
serving or an 80 GB card.

## Deployment speedups (NCU kernel-only, single-user decode)

Single-user inference is dominated by M=1 decode (80-84% of total GEMM
time, from the workload analysis in `token_analysis.md`). The weighted
per-block speedup:

| k | Decode speedup (M=1) | Weighted overall (decode + prefill) |
|---|---------------------|-------------------------------------|
| 2 | ~1.90x | ~1.58x |
| 3 | ~1.70x | ~1.45x |
| 4 | ~1.50x | ~1.30x |
| 5 | ~1.35x | ~1.18x |

Prefill uses dequant + cuBLAS, which is slightly slower than pure fp16.
But prefill is infrequent: a typical turn has 1 prefill pass + 114 decode
steps, so the decode speedup dominates.
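Phase speedups combine harmonically, since it is time that adds. A sketch under assumed inputs -- the 80% decode share is the low end of the range above, and the 0.97x prefill speedup is illustrative -- which lands within a few percent of the k=2 table entry:

```python
def weighted_speedup(decode_speedup: float,
                     decode_share: float = 0.80,    # low end of the 80-84% range
                     prefill_speedup: float = 0.97  # assumed: dequant+cuBLAS slightly slower
                     ) -> float:
    """Overall speedup when each phase's time scales by its own speedup."""
    new_time = decode_share / decode_speedup + (1 - decode_share) / prefill_speedup
    return 1.0 / new_time

print(f"k=2: {weighted_speedup(1.90):.2f}x weighted overall")
```

The exact table values presumably use the per-k prefill measurements, which are not reproduced here.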

## Deployment speedups (NCU kernel-only, 4-user vLLM)

With 4 concurrent users in vLLM continuous batching, the M distribution is
bimodal: M=4 for decode-only iterations (92.6% of iterations) and M=4+chunk
for decode+prefill iterations. The scalar kernel handles 59% of GEMM time,
dequant+cuBLAS handles 41%.

| k | 4-user weighted speedup |
|---|------------------------|
| 2 | ~1.58x |
| 3 | ~1.40x |
| 4 | ~1.25x |
| 5 | ~1.12x |

The crossover where quantized kernels become slower than fp16 is at ~16
concurrent users. Below that, the bandwidth savings from k-bit compression
outweigh the dequant overhead. Above that, the per-shape dequant cost
dominates because most iterations include a large prefill chunk where cuBLAS
is highly efficient.

## Dequant kernel NCU times (bandwidth model at 815 GB/s)

The dequant kernel (`kDequantizeBlockwise_kbit_vec`) reads k-bit packed data
plus absmax and writes fp16 output. Times scale with element count and k:

| Shape | Elements | k=2 | k=3 | k=4 | k=5 |
|-------|----------|-----|-----|-----|-----|
| gateup/down | 10.5M | 29.3 us | 30.5 us | 31.8 us | 33.1 us |
| Q/O | 8.4M | 23.5 us | 24.4 us | 26.1 us | 27.3 us |
| KV | 1.0M | 2.9 us | 3.0 us | 3.2 us | 3.4 us |

k=2 is fastest because it reads only 0.25 bytes/element packed; k=5 reads
0.625 bytes/element. The fp16 output write (2 bytes/element) dominates
bandwidth regardless of k, which is why the spread is only ~15%.
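These times are consistent with a simple bandwidth model: (k/8 bytes read + 2 bytes written) per element, divided by 815 GB/s. A sketch that ignores absmax traffic, which is why it tracks the NCU numbers only to within a few percent:

```python
BW_GBPS = 815  # effective DRAM bandwidth assumed by the model

def dequant_time_us(elements: float, k: int) -> float:
    """Predicted dequant kernel time: k-bit packed read plus fp16 write."""
    bytes_moved = elements * (k / 8 + 2)  # read k/8 B/elt, write 2 B/elt
    return bytes_moved / (BW_GBPS * 1e9) * 1e6

# gateup/down (10.5M elements): NCU measured 29.3 us (k=2) .. 33.1 us (k=5)
for k in (2, 3, 4, 5):
    print(f"k={k}: {dequant_time_us(10.5e6, k):.1f} us predicted")
```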

## Issue: Python dispatch overhead in bitsandbytes custom ops

Profiled the per-call overhead of custom CUDA kernels (kbit dequant as the
test case, but this applies to all ops going through `torch.library`). For
a kernel that takes 26 us on-GPU (NCU), the CUDA-events end-to-end time is
51 us -- nearly 2x the kernel itself.

Breakdown of the ~25 us overhead:

```
torch.ops dispatch routing:  ~10 us (library registry lookup, dispatch key resolution)
functional.py wrapper:        ~9 us (argument reordering, out[:n] slice)
torch._check x 4:             ~5 us (runtime type/dtype assertions)
torch.empty (16 MB output):   ~4 us (allocator)
CUDA driver launch:           ~3 us (kernel submission)
```

For comparison, calling the kernel directly through ctypes (bypassing
`torch.library` entirely) measures 3.3 us overhead -- the raw CUDA driver
launch cost. The remaining ~22 us is pure Python/PyTorch framework overhead.

### Why this matters for deployment

At M=1 decode (the dominant workload), a typical Qwen3 transformer block
has 7 weight-matmul kernel launches. At 25 us overhead each, that is 175 us
of pure dispatch overhead per block -- comparable to the total kernel
compute time. For the dequant+cuBLAS path (M>16), each shape needs 2
kernel launches (dequant + matmul), doubling the dispatch tax.
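The per-block arithmetic, as a quick sanity check (7 launches is the block's matmul count from above; the dequant+cuBLAS path doubles it at 2 launches per shape):

```python
OVERHEAD_US = 25  # measured per-call torch.library overhead

def dispatch_tax_us(n_shapes: int, launches_per_shape: int = 1) -> int:
    """Pure dispatch overhead per transformer block, in microseconds."""
    return n_shapes * launches_per_shape * OVERHEAD_US

print(dispatch_tax_us(7))     # M=1 decode path: 175 us per block
print(dispatch_tax_us(7, 2))  # dequant+cuBLAS path: 350 us per block
```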

### Possible mitigations

1. **CUDA graphs**: capture the dispatch sequence and replay it,
   eliminating per-call Python overhead. Requires static shapes or
   shape-bucketed graphs. This is the standard production solution.
2. **Direct ctypes dispatch**: bypass `torch.library` for hot-path ops.
   Reduces overhead from 25 us to ~3 us. Loses `torch.compile`
   compatibility.
3. **Fuse dequant into matmul**: eliminate the separate dequant kernel
   launch entirely for M>16. Requires a custom matmul kernel that reads
   k-bit weights directly (the MMA kernel already does this for M<=16).
4. **Reduce `torch._check` calls**: the 4 runtime assertions add ~5 us.
   These could be gated behind a debug flag.
5. **Eliminate argument reordering**: `functional.py` reorders arguments
   before calling `torch.ops`. Aligning the public API with the internal
   op signature would save ~9 us.

## Conclusions

1. **K-bit quantization provides significant speedups for low-concurrency
   serving.** At k=2, single-user decode is ~1.9x faster than fp16 while
   using 8x less memory. Even k=4 gives a 1.5x decode speedup with 4x
   compression.

2. **The sweet spot is 1-4 concurrent users.** The scalar GEMV kernel
   dominates at this scale and is bandwidth-bound -- it directly benefits
   from reading less data. At 16+ users, prefill overhead erodes the
   advantage.

3. **Python dispatch overhead is the next bottleneck.** The 25 us per-call
   overhead nearly doubles the effective kernel time at M=1. Addressing
   it (via CUDA graphs, direct ctypes dispatch, or fused ops) would
   improve end-to-end throughput by up to 1.5x on top of the current
   kernel speedups.

4. **MoE grouped kernels need the V8 optimizations.** The grouped scalar
   GEMV currently matches fp16 but does not beat it. Porting the V8 inner
   loop (vectorized A loads, 2-warp config, M-dispatch) would bring it
   closer to the 1.5-1.9x speedups seen on dense layers.

5. **Lower k is strictly better for inference speed, not just model size.**
   k=2 is fastest at every M value because it reads the least data. The
   accuracy-speed tradeoff is the only reason to use a higher k.
