Commit df55daf

TimDettmers and claude committed

Add workload-weighted kernel analysis and vLLM deployment model

- token_analysis.md: workload analysis using 397 sessions of real token distributions. Single-user: M=1 decode is 80-84% of GEMM time. Multi-user vLLM simulation (1-64 users): bimodal M distribution (decode-only vs decode+prefill chunk), crossover at ~16 users.
- token_distributions.json: per-turn frequency distributions for prefill and decode token counts (power-of-two buckets, sum to 1.0).
- kbit-kernel-spec.md: updated dequant section (single kernel launch, ncu-measured times), added practical kernel importance table showing scalar GEMV dominates at 1-4 users, dq+cuBLAS at 16+, MMA has minimal impact in either regime.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b02ff66 commit df55daf

File tree: 3 files changed, +270 -11 lines


kbit-kernel-spec.md

Lines changed: 49 additions & 11 deletions
```diff
@@ -110,6 +110,14 @@ The batch size M seen by each kernel varies:
 - **M=1-32+**: dense layers (full batch)
 - **M=32-512+**: prefill / prompt processing

+See `token_analysis.md` for a detailed workload analysis using real
+token distributions from 397 Claude Code sessions. The analysis shows
+that in single-user inference, M=1 decode accounts for 80-84% of total
+GEMM time. In multi-user vLLM serving, the M distribution is bimodal
+(M=num_users for decode-only iterations, M=num_users+chunk for prefill
+iterations), and the crossover where quantized kernels become slower
+than fp16 is at ~16 concurrent users.
+
 ---

 ## Four-kernel strategy
@@ -137,6 +145,24 @@ Why four kernels instead of one:
 - MoE experts launched individually waste 88-97% of SMs. Grouping
   all active experts into one kernel launch solves this.

+**Practical importance (from workload analysis in `token_analysis.md`):**
+
+In real deployments, the M distribution is bimodal — not uniform. With
+vLLM continuous batching, iterations are either pure-decode (M=num_users)
+or decode+prefill (M=num_users+chunk_size). The MMA kernel's M=5-16
+range falls in the gap between these modes.
+
+| Scenario | Scalar share | MMA share | dq+cuBLAS share |
+|----------|-------------|-----------|-----------------|
+| 1 user | 87% | 0% | 13% |
+| 4 users | 59% | 0% | 41% |
+| 8 users | 0% | 45% | 55% |
+| 16 users | 0% | 24% | 76% |
+| 32+ users | 0% | 6% | 94% |
+
+Optimization priority: scalar GEMV (1-4 users) > dequant overhead
+reduction (16+ users) > MMA kernel (8-16 users only, narrow range).
+
 ---

 ## 1. Scalar GEMV (`kbit_scalar_gemv`)
@@ -306,20 +332,32 @@
 the MMA dequant kernel takes ~68 us (instruction-limited, only 1.3%
 of execution is MMA). A fused dequant kernel would take ~5 us for
 this shape, so dequant + cuBLAS ~27 us would beat 68 us.

-**Current dequant implementation is not fused.** `dequantize_kbit`
-dispatches ~15 PyTorch elementwise kernels per call, giving a constant
-~800 us overhead regardless of shape. This makes dequant + cuBLAS
-non-competitive at M<64. A fused dequant CUDA kernel is needed for
-strategy 3 to be viable.
+**Dequant kernel** (`kDequantizeBlockwise_kbit_vec`): a single CUDA
+kernel that reads k-bit packed data + absmax and writes fp16 output.
+Templated on absmax type: float32 (from `quantize_kbit` directly),
+uint8 E4M4, or fp16. The float32 absmax path was added to eliminate
+a previous Python-side E4M4 conversion that launched ~15 PyTorch
+elementwise kernels (~800 us). Now it is a single kernel launch.
+
+Dequant GPU kernel times (ncu-measured, k=4):
+
+| Shape | Elements | Kernel time |
+|-------|----------|-------------|
+| gateup/down | 10.5M | ~30 us |
+| Q/O | 8.4M | ~25 us |
+| KV | 1.0M | ~5 us |
+
+Times scale linearly with element count and k.

-The crossover point depends on shape. For DRAM-bound shapes (Llama3-8B
-gate/up at 4096x14336), the MMA dequant kernel wins at 1.5x over
-cuBLAS because the 3.2x bandwidth savings dominate. For L2-resident
-shapes (MoE experts, small dense layers), cuBLAS wins because the
-kernel is instruction-limited, not bandwidth-limited.
+**Crossover vs MMA:** At M<=16, MMA beats dequant+cuBLAS on most
+shapes because the fixed dequant cost (~25-30 us) is large relative
+to the matmul. At M>=64, dequant+cuBLAS wins because cuBLAS scales
+efficiently while MMA is instruction-limited. The crossover is
+M=32-64 depending on shape.

 **Data format:** Uses flat layout (same as scalar GEMV). The
-`dequantize_kbit` launcher handles both uint8 E4M4 and float32 absmax.
+`dequantize_kbit` launcher handles float32, uint8 E4M4, and fp16
+absmax via the `_KBIT_ABSMAX_SUFFIX` dispatch map.

 ---
```

token_analysis.md

Lines changed: 171 additions & 0 deletions
# Claude Code Token Analysis

## Session data location

Session JSONL files are stored at:

```
~/.claude/projects/<project-path>/<session-id>.jsonl
```

Each file contains one JSON object per line with types: `user`, `assistant`, `system`, `progress`, `file-history-snapshot`.

## Methodology

### Input tokens (prefill)

Input = user prompts + tool results. These are measured from `user`-type messages in the JSONL:

- `content[].type == "text"` entries give user prompt text
- `content[].type == "tool_result"` entries give tool outputs (file reads, grep, bash)

Token counts are estimated at chars/4. The system prompt, system injections, and the model's own prior output re-read as context are excluded — we only count new content the user/tools provide.

### Generated tokens (decode)

Generated = `output_tokens` from the `usage` field on `assistant`-type messages. This includes all model generation: text responses, tool call arguments, and thinking tokens (thinking content is encrypted, so it cannot be separated out).

### Per-turn grouping

A "turn" = one user message + all assistant API calls until the next user message. A single user turn may trigger multiple API calls (the model calls a tool, gets the result, calls another tool, etc.). Input for a turn = content in that user message. Output for a turn = sum of `output_tokens` across all API calls in that turn.

### Histogram bucketing

Values are bucketed to the nearest power of two: `2^round(log2(n))`.
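As a worked check of the rule (mirroring the `bucket` helper in the script at the end of this document): values round to the *nearest* power of two, so 24 buckets up to 32, while 20 buckets down to 16.

```python
import math

def bucket(n: int) -> int:
    # Nearest power of two: 2^round(log2(n)); zero for non-positive n.
    if n <= 0:
        return 0
    return 2 ** round(math.log2(n))

print(bucket(20), bucket(24), bucket(1000))  # 16 32 1024
```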
## Aggregate results: 397 sessions, 25,162 user turns

Data collected from 472 session files across all projects (75 empty/skipped). 41,537 total API calls.

| | Est. tokens |
|---|---:|
| Input (prefill) | ~31.8M |
| Generated (decode) | ~2.3M |
| **Ratio** | **13.7:1 input to output** |

### Frequency distributions

Per-turn frequency distributions (summing to 1.0) are stored in `token_distributions.json`. The file contains two distributions:

- `input_tokens_per_turn.freq` — estimated prefill tokens per user turn (user text + tool results). 24,155 non-empty turns.
- `generated_tokens_per_turn.freq` — decode tokens per user turn (from API `output_tokens`). 20,911 non-empty turns.

Keys are power-of-two bucket sizes (as strings), values are frequencies.
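A sketch of how the `freq` maps are meant to be consumed: the bucket-weighted mean approximates the average tokens per turn. The values below are copied from `generated_tokens_per_turn.freq`; the mean comes out at ~114 tokens, the figure used in the kernel analysis below.

```python
# generated_tokens_per_turn.freq, copied from token_distributions.json
freq = {
    1: 0.065502, 2: 0.201518, 4: 0.109902, 8: 0.100449,
    16: 0.112766, 32: 0.19302, 64: 0.04144, 128: 0.055619,
    256: 0.048219, 512: 0.034947, 1024: 0.022057, 2048: 0.010312,
    4096: 0.003533, 8192: 0.000668, 16384: 4.8e-05,
}
assert abs(sum(freq.values()) - 1.0) < 1e-3   # frequencies sum to 1.0
mean = sum(b * f for b, f in freq.items())    # bucket-weighted mean
print(f"mean generated tokens/turn: {mean:.1f}")  # ~114
```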
### Interpretation

- Input peaks at 16-32 tokens (short prompts, small tool results) with a flat tail through 2048. This reflects a mix of user typing (small) and tool results (variable).
- Output is bimodal: peaks at 2 tokens (20%, a single short tool call) and 32 tokens (19%, a tool call with a moderate argument). Text responses and code blocks (128-2048) account for ~17% of turns.
- Heavy generation (>4096 tokens) is rare (<0.5% of turns).

## Kernel performance weighted by workload

The token distributions in `token_distributions.json` serve as a workload model for estimating which GEMM kernels matter most in practice. The key mapping: **input tokens per turn = prefill M** (new tokens processed in a single forward pass with KV cache), **generated tokens per turn = number of decode steps at M=1** (or M=batch_size in multi-user serving).

### Single-user inference (M=1 decode)

In single-user autoregressive generation, each turn involves:

- **1 prefill pass** at M = input_tokens (prompt/tool results, distributed by `input_tokens_per_turn`)
- **N decode passes** at M = 1, where N is the number of generated tokens (distributed by `generated_tokens_per_turn`)

The average number of generated tokens per turn is ~114, so a typical turn has 1 prefill pass + 114 decode passes. Even though large prefills are individually expensive (a single M=32768 pass costs ~23,000 us/layer), they are rare enough (~1.4% frequency) that decode at M=1 dominates total wall-clock time at **80-84%** across k=2..5.

Per-layer time breakdown (k=4, Qwen3-Coder-Next shapes):

| Component | Time/turn/layer | % of total |
|-----------|----------------:|------------|
| Decode (114 steps x 55.6 us) | 6,347 us | 83.4% |
| Prefill (distributed) | 1,260 us | 16.6% |

The scalar GEMV kernel (M=1) is faster than fp16 cuBLAS because it reads 3-4x less data (k-bit compressed weights vs fp16). Overall weighted slowdown vs fp16: **0.57x** (i.e., 43% faster) at k=4.
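The decode share in the table follows directly from the numbers already quoted: ~114 decode steps per turn at 55.6 us per step, versus the distribution-weighted prefill cost.

```python
# Per-turn, per-layer split at k=4 (numbers from the table above).
decode_steps = 114.1             # bucket-weighted mean decode steps/turn
decode_us = decode_steps * 55.6  # ~6,344 us of M=1 scalar GEMV
prefill_us = 1260.0              # distribution-weighted prefill time
share = decode_us / (decode_us + prefill_us)
print(f"decode share: {share:.1%}")  # ~83.4%
```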
### Multi-user serving with vLLM

Production deployments use continuous batching (vLLM), which fundamentally changes the M distribution. The vLLM V1 scheduler (`vllm/v1/core/sched/scheduler.py`) works as follows:

1. **Decode-first**: all running (decoding) requests are scheduled first, each contributing 1 token. M starts at num_decoding_users.
2. **Chunked prefill**: the remaining token budget is used for at most one prefill chunk from a waiting request. The default chunk size is `max_model_len * 0.04` (e.g., 1280 for 32K context, 5120 for 128K).
3. **Token budget cap**: total tokens per step is bounded by `max_num_batched_tokens` (default 8192).
4. **One partial prefill at a time**: `max_num_partial_prefills` defaults to 1.

This creates a **bimodal M distribution**: iterations are either pure-decode (M = num_users) or decode + prefill chunk (M = num_users + chunk_size). The MMA kernel's effective range (M=8-32) falls in the gap between these modes and is rarely used.
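The scheduler steps above can be sketched as a toy model (illustrative only, not the actual vLLM code; `iteration_m` and its defaults are hypothetical):

```python
def iteration_m(num_users: int, pending_prefill: int,
                chunk_size: int = 512,
                max_num_batched_tokens: int = 8192) -> int:
    """Batch size M for one continuous-batching step (toy model)."""
    m = num_users  # decode-first: every running request adds 1 token
    if pending_prefill > 0:
        # At most one prefill chunk, capped by the remaining token budget.
        m += min(chunk_size, pending_prefill, max_num_batched_tokens - m)
    return m

print(iteration_m(8, 0))     # pure decode: M = 8
print(iteration_m(8, 2048))  # decode + prefill chunk: M = 8 + 512 = 520
```

With only these two modes, M values between num_users and num_users + chunk_size never occur, which is exactly the gap the MMA kernel falls into.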
Simulation results (k=4, chunk_size=512, token distributions from `token_distributions.json`):

| Users | Avg M | Decode-only iters | Dominant kernel | vs fp16 |
|------:|------:|------------------:|-----------------|--------:|
| 1 | 8 | 98.6% | scalar (87%) | 0.57x |
| 4 | 41 | 92.6% | scalar (59%) + dq+cuBLAS (41%) | 0.76x |
| 8 | 77 | 86.1% | MMA (45%) + dq+cuBLAS (55%) | 0.85x |
| 16 | 163 | 70.2% | dq+cuBLAS (76%) | 1.00x |
| 32 | 364 | 30.9% | dq+cuBLAS (93%) | 1.17x |
| 64 | 495 | 5.1% | dq+cuBLAS (98%) | 1.23x |

The crossover where quantized kernels become slower than fp16 is at **~16 concurrent users**. Below that, the bandwidth savings from k-bit compression outweigh the dequant overhead. Above that, the dequant cost (~30 us/shape at k=4) dominates because most iterations include a large prefill chunk where cuBLAS is highly efficient.

### Optimization priorities

The analysis identifies two regimes with different optimization targets:

**1-4 users (agents, local inference, code assistants):**
The scalar GEMV at M=1..4 accounts for 59-87% of total GEMM time. This kernel is already bandwidth-bound and faster than fp16. Further optimization (better ILP in the M-loop, wider vector loads) has the highest leverage. The dq+cuBLAS path handles the occasional prefill chunk (~41% of time at 4 users) with moderate overhead (1.25x vs fp16). The MMA kernel is effectively unused.

**16+ users (serving, API endpoints):**
dq+cuBLAS dominates (75-98% of time). The ~30 us dequant overhead per shape at k=4 is the primary cost. Reducing this — through a faster dequant kernel, fusing dequant into the matmul, or accepting float32 absmax to skip format conversion — would directly reduce the 1.17-1.23x slowdown vs fp16.

**The MMA kernel has minimal impact in either regime.** Its effective range (M=8-32) corresponds to pure-decode batches at 8-32 users, a shrinking slice of iterations as user count grows. At 4 users, M never reaches the MMA range. At 32 users, only 31% of iterations are pure-decode at M=32, and MMA accounts for just 5.8% of total weighted time.
## Script

```python
import json, math, os

# Placeholder path; substitute a real <project>/<session-id>.
SESSION = "~/.claude/projects/<project>/<session-id>.jsonl"

with open(os.path.expanduser(SESSION)) as f:
    lines = [json.loads(l) for l in f if l.strip()]

timeline = [l for l in lines if l.get('type') in ('user', 'assistant')]

turns = []
for i, msg in enumerate(timeline):
    if msg['type'] != 'user':
        continue
    # Estimate input characters: user text plus tool_result payloads.
    content = msg.get('message', {}).get('content', '')
    input_chars = 0
    if isinstance(content, list):
        for c in content:
            if c.get('type') == 'text':
                input_chars += len(c.get('text', ''))
            elif c.get('type') == 'tool_result':
                rc = c.get('content', '')
                if isinstance(rc, str):
                    input_chars += len(rc)
                elif isinstance(rc, list):
                    input_chars += sum(len(json.dumps(x)) for x in rc)
    elif isinstance(content, str):
        input_chars += len(content)

    # Sum output tokens over all assistant calls until the next user message.
    total_output = 0
    for j in range(i + 1, len(timeline)):
        if timeline[j]['type'] == 'user':
            break
        if timeline[j]['type'] == 'assistant':
            usage = timeline[j].get('message', {}).get('usage', {})
            total_output += usage.get('output_tokens', 0)

    turns.append({'input_est': input_chars // 4, 'output': total_output})

def bucket(n):
    # Nearest power of two: 2^round(log2(n)); zero for non-positive n.
    if n <= 0:
        return 0
    return 2 ** round(math.log2(n))

for label, key in [("Input", "input_est"), ("Generated", "output")]:
    vals = [t[key] for t in turns if t[key] > 0]
    buckets = {}
    for v in vals:
        b = bucket(v)
        buckets[b] = buckets.get(b, 0) + 1
    mx = max(buckets.values())
    print(f"\n=== {label} tokens per turn ({len(vals)} turns) ===")
    for b in sorted(buckets):
        bar = "#" * max(1, round(buckets[b] / mx * 40))
        print(f"{b:>8} {buckets[b]:>5} {bar}")
```

token_distributions.json

Lines changed: 50 additions & 0 deletions
```json
{
  "description": "Token count frequency distributions across 397 Claude Code sessions (472 files, 75 empty). Buckets are nearest power of two. Frequencies sum to 1.0.",
  "sessions": 397,
  "input_tokens_per_turn": {
    "description": "Estimated input tokens per user turn (user text + tool results, chars/4). Only non-empty turns included.",
    "num_turns": 24229,
    "freq": {
      "1": 0.00388,
      "2": 0.006108,
      "4": 0.049858,
      "8": 0.062198,
      "16": 0.152256,
      "32": 0.167733,
      "64": 0.100045,
      "128": 0.093566,
      "256": 0.085022,
      "512": 0.086219,
      "1024": 0.064221,
      "2048": 0.050147,
      "4096": 0.035742,
      "8192": 0.014157,
      "16384": 0.013785,
      "32768": 0.013703,
      "65536": 0.001279,
      "131072": 4.1e-05,
      "262144": 4.1e-05
    }
  },
  "generated_tokens_per_turn": {
    "description": "Output tokens per user turn (from API usage.output_tokens, includes text + tool calls + thinking). Only non-empty turns included.",
    "num_turns": 20946,
    "freq": {
      "1": 0.065502,
      "2": 0.201518,
      "4": 0.109902,
      "8": 0.100449,
      "16": 0.112766,
      "32": 0.19302,
      "64": 0.04144,
      "128": 0.055619,
      "256": 0.048219,
      "512": 0.034947,
      "1024": 0.022057,
      "2048": 0.010312,
      "4096": 0.003533,
      "8192": 0.000668,
      "16384": 4.8e-05
    }
  }
}
```
