Commit 14ee2e9

TimDettmers and claude committed
bench: Add comprehensive VLM benchmark + fix CUDA graph timing methodology
- Add bench_kbit_vlm.py: sweeps all kernel variants (scalar GEMV, MMA, dequant+cuBLAS) with and without Hadamard rotation across VLM-relevant M values (1-1024) on Qwen3-Coder-Next 70B shapes, k=2..5.
- Rewrite bench_hadamard.py to use batched graph replay: replay the captured graph N times within one event pair, then divide. This amortizes the ~14 us per-replay timing floor to <0.03 us, revealing true kernel execution times that were previously masked.
- Update CLAUDE.md with benchmarking section: where to find scripts, how to run them, and the batched replay methodology.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fcfca9f commit 14ee2e9

File tree

3 files changed

+400 -74 lines changed

CLAUDE.md

Lines changed: 29 additions & 0 deletions
````diff
@@ -23,3 +23,32 @@ pytest tests/ -v --tb=short -n 4
 ```
 
 Best practices, benchmark data, and known architecture-specific issues: `agents/testing_guide.md`
+
+# Benchmarking
+
+Benchmark scripts live in `benchmarks/`. The two kbit-specific ones:
+
+- `bench_hadamard.py` — Hadamard rotation kernel + M=1 pipeline (rotation + scalar GEMV) vs cuBLAS FP16. Quick focused benchmark for the decode path.
+- `bench_kbit_vlm.py` — Comprehensive sweep across all VLM-relevant M values (1 to 1024), all kernel variants (scalar GEMV, MMA, dequant+cuBLAS), all k values (2-5), with and without Hadamard rotation. Qwen3-Coder-Next 70B shapes.
+
+```bash
+# Quick M=1 decode benchmark
+python benchmarks/bench_hadamard.py
+
+# Full VLM sweep (all M, all k)
+python benchmarks/bench_kbit_vlm.py
+
+# Single k value, subset of M
+python benchmarks/bench_kbit_vlm.py --k 4 --m 1,4,16,256,1024
+
+# Higher accuracy (more iterations)
+python benchmarks/bench_kbit_vlm.py --inner 1000 --outer 30
+```
+
+## CUDA graph benchmarking methodology
+
+Single graph replay has a ~14 us timing floor (on RTX 4090) that masks sub-14 us kernel differences. The benchmarks use **batched graph replay**: replay the graph N times within one event-timed region, then divide. This amortizes the per-replay overhead to ~14/N us per iteration.
+
+The `--inner` flag controls N (replays per measurement). The default of 500 gives ~0.03 us amortized overhead. Use `--inner 1000` for the highest accuracy when comparing kernels that differ by < 1 us.
+
+`--outer` controls the number of measurements (default 15). The median is reported to reject outliers.
````
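The amortization argument behind batched replay is easy to demonstrate without a GPU. Below is a minimal CPU-only sketch of the same idea using `time.perf_counter` in place of CUDA events; `timed_per_call` and `timed_batched` are illustrative names, not functions from this repo, and the overheads involved are wall-clock timer costs rather than the ~14 us graph-replay floor:

```python
import time

def timed_per_call(fn, n):
    # Naive: one timer pair per call, so fixed timing overhead
    # is counted in every single measurement.
    total = 0.0
    for _ in range(n):
        t0 = time.perf_counter()
        fn()
        total += time.perf_counter() - t0
    return total / n  # seconds per iteration

def timed_batched(fn, inner, outer=15):
    # Batched: time `inner` calls inside one timed region, then
    # divide, amortizing the fixed overhead to overhead/inner.
    # Median of `outer` measurements rejects outliers.
    samples = []
    for _ in range(outer):
        t0 = time.perf_counter()
        for _ in range(inner):
            fn()
        samples.append((time.perf_counter() - t0) / inner)
    samples.sort()
    return samples[len(samples) // 2]

fast = lambda: None  # stands in for a sub-microsecond kernel
naive = timed_per_call(fast, 1000)
batched = timed_batched(fast, inner=1000)
print(f"naive: {naive * 1e6:.3f} us/iter, batched: {batched * 1e6:.3f} us/iter")
```

For a no-op workload the batched figure approaches pure loop overhead, while the naive figure also carries a timer read per iteration; the GPU benchmarks apply the same division trick to CUDA event pairs around `graph.replay()` loops.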

benchmarks/bench_hadamard.py

Lines changed: 76 additions & 74 deletions
```diff
@@ -1,14 +1,23 @@
-"""Benchmark for Hadamard rotation kernel and full kbit pipeline.
+"""Benchmark for Hadamard rotation kernel and kbit M=1 pipeline.
 
 Measures:
-1. Rotation standalone: all block sizes × Qwen3 K values × M=1,4
-2. Full pipeline (rotate + kbit_scalar_gemv_tiled): Qwen3 dense shapes at M=1, k=2,3,4
+1. Rotation standalone: all block sizes x Qwen3 K values x M=1,4
+2. Full pipeline (rotate + kbit_scalar_gemv_tiled): Qwen3 dense shapes at M=1, k=2..5
 3. cuBLAS FP16 baseline: same shapes
 4. Speedup table: pipeline vs cuBLAS
 
-All timing via CUDA graph capture + replay for clean kernel-only measurements.
+Timing methodology:
+    CUDA graph capture + batched replay. Each measurement replays the graph
+    INNER times within a single event-timed region, then divides. This
+    amortizes the ~14 us per-replay overhead down to negligible levels,
+    revealing true kernel execution times. Median of OUTER measurements.
+
+Usage:
+    python benchmarks/bench_hadamard.py
+    python benchmarks/bench_hadamard.py --inner 1000 --outer 30  # higher accuracy
 """
 
+import argparse
 import sys
 
 import torch
@@ -22,9 +31,7 @@
     quantize_kbit,
 )
 
-BLOCKSIZE = 32
-WARMUP = 50
-ITERS = 200
+ROTATION_BLOCK_SIZE = 64
 
 
 def create_normal_float_codebook(k: int) -> torch.Tensor:
@@ -35,14 +42,18 @@ def create_normal_float_codebook(k: int) -> torch.Tensor:
     return values.cuda()
 
 
-def bench_graph(fn, warmup=WARMUP, iters=ITERS):
-    """Time a function using CUDA graph capture + replay. Returns median time in us."""
-    # Warm up on default stream
-    for _ in range(warmup):
+def bench(fn, inner: int, outer: int) -> float:
+    """Batched CUDA graph replay timing. Returns median us per iteration.
+
+    Captures fn into a CUDA graph, then replays it `inner` times within a
+    single CUDA event pair. The per-replay overhead (~14 us on RTX 4090)
+    is amortized to ~14/inner us per iteration. Takes the median of `outer`
+    such measurements.
+    """
+    for _ in range(30):
        fn()
    torch.cuda.synchronize()
 
-    # Capture graph
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
@@ -55,27 +66,34 @@ def bench_graph(fn, warmup=WARMUP, iters=ITERS):
        fn()
    torch.cuda.synchronize()
 
-    # Warm up replay
-    for _ in range(10):
+    for _ in range(50):
        g.replay()
    torch.cuda.synchronize()
 
-    # Time replay
    times = []
-    for _ in range(iters):
+    for _ in range(outer):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
-        g.replay()
+        for _ in range(inner):
+            g.replay()
        end.record()
        torch.cuda.synchronize()
-        times.append(start.elapsed_time(end) * 1000)  # ms -> us
-
+        times.append(start.elapsed_time(end) * 1000 / inner)  # ms -> us/iter
    times.sort()
-    return times[len(times) // 2]  # median
+    return times[len(times) // 2]
+
+
+def prepare_kbit_weights(K_dim, N, k):
+    """Quantize random weights and repack for tiled access."""
+    W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
+    codebook = create_normal_float_codebook(k)
+    packed, absmax, _ = quantize_kbit(W, k=k, codebook=codebook)
+    packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(packed, absmax, K_dim, N, k)
+    return packed_tiled, absmax_tiled, codebook
 
 
-def bench_rotation_standalone():
+def bench_rotation_standalone(inner, outer):
    """Benchmark rotation kernel standalone across block sizes and shapes."""
    print("=" * 70)
    print("1. ROTATION STANDALONE")
@@ -91,31 +109,20 @@ def bench_rotation_standalone():
    for K in k_values:
        for bs in block_sizes:
            A = torch.randn(M, K, dtype=torch.float16, device="cuda")
-            t = bench_graph(lambda: hadamard_rotate(A, block_size=bs))
-            # BW: read + write = 2 * numel * 2 bytes (fp16)
+            t = bench(lambda: hadamard_rotate(A, block_size=bs), inner, outer)
            bw = 2 * A.numel() * 2 / (t / 1e6) / 1e9
-            print(f"{M:>4} {K:>6} {bs:>4} {t:>10.2f} {bw:>10.1f}")
+            print(f"{M:>4} {K:>6} {bs:>4} {t:>10.3f} {bw:>10.1f}")
    print()
 
 
-def prepare_kbit_weights(K_dim, N, k):
-    """Quantize random weights and repack for tiled access."""
-    W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
-    codebook = create_normal_float_codebook(k)
-    packed, absmax, _ = quantize_kbit(W, k=k, codebook=codebook)
-    packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(packed, absmax, K_dim, N, k)
-    return packed_tiled, absmax_tiled, codebook
-
-
-def bench_pipeline():
+def bench_pipeline(inner, outer):
    """Benchmark full pipeline: rotate(A) + kbit_scalar_gemv."""
    print("=" * 70)
    print("2. FULL PIPELINE: rotate + kbit_scalar_gemv_tiled")
    print("=" * 70)
    print(f"{'M':>4} {'K':>6} {'N':>6} {'k':>2} {'Rotate(us)':>11} {'GEMV(us)':>9} {'Total(us)':>10} {'TFLOPS':>7}")
    print("-" * 65)
 
-    # Qwen3-Coder-Next 70B dense shapes
    shapes = [
        (1, 2048, 5120, "gate/up"),
        (1, 5120, 2048, "down"),
@@ -126,42 +133,41 @@ def bench_pipeline():
        (4, 5120, 2048, "down M=4"),
    ]
 
-    for k in [2, 3, 4]:
+    for k in [2, 3, 4, 5]:
        print(f"\n--- k={k} ---")
        for M, K_dim, N, label in shapes:
            packed_tiled, absmax_tiled, codebook = prepare_kbit_weights(K_dim, N, k)
            A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
 
-            # Benchmark rotation alone
            A_copy = A.clone()
-            t_rot = bench_graph(lambda: hadamard_rotate(A_copy, block_size=64))
+            t_rot = bench(lambda: hadamard_rotate(A_copy, block_size=ROTATION_BLOCK_SIZE), inner, outer)
 
-            # Benchmark GEMV alone (tiled layout, pre-allocated output)
            out = torch.zeros(M, N, dtype=torch.float16, device="cuda")
-            t_gemv = bench_graph(
+            t_gemv = bench(
                lambda: torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
                    A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out
-                )
+                ),
+                inner,
+                outer,
            )
 
-            # Benchmark combined
            def pipeline():
-                hadamard_rotate(A_copy, block_size=64)
+                hadamard_rotate(A_copy, block_size=ROTATION_BLOCK_SIZE)
                torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
                    A_copy, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out
                )
 
-            t_total = bench_graph(pipeline)
+            t_total = bench(pipeline, inner, outer)
 
            flops = 2 * M * K_dim * N
            tflops = flops / (t_total / 1e6) / 1e12
            print(
-                f"{M:>4} {K_dim:>6} {N:>6} {k:>2} {t_rot:>11.2f} {t_gemv:>9.2f} "
-                f"{t_total:>10.2f} {tflops:>7.3f} {label}"
+                f"{M:>4} {K_dim:>6} {N:>6} {k:>2} {t_rot:>11.3f} {t_gemv:>9.3f} "
+                f"{t_total:>10.3f} {tflops:>7.3f} {label}"
            )
 
 
-def bench_cublas_baseline():
+def bench_cublas_baseline(inner, outer):
    """Benchmark cuBLAS FP16 GEMM for the same shapes."""
    print("\n" + "=" * 70)
    print("3. cuBLAS FP16 BASELINE")
@@ -184,16 +190,16 @@ def bench_cublas_baseline():
        W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
        out = torch.empty(M, N, dtype=torch.float16, device="cuda")
 
-        t = bench_graph(lambda: torch.mm(A, W.t(), out=out))
+        t = bench(lambda: torch.mm(A, W.t(), out=out), inner, outer)
        flops = 2 * M * K_dim * N
        tflops = flops / (t / 1e6) / 1e12
-        print(f"{M:>4} {K_dim:>6} {N:>6} {t:>9.2f} {tflops:>7.3f}")
+        print(f"{M:>4} {K_dim:>6} {N:>6} {t:>9.3f} {tflops:>7.3f}")
 
 
-def bench_speedup_table():
+def bench_speedup_table(inner, outer):
    """Print a speedup comparison table: pipeline vs cuBLAS."""
    print("\n" + "=" * 70)
-    print("4. SPEEDUP TABLE: kbit pipeline vs cuBLAS FP16")
+    print("4. SPEEDUP TABLE: Rot + kbit GEMV vs cuBLAS FP16")
    print("=" * 70)
 
    shapes = [
@@ -208,7 +214,7 @@ def bench_speedup_table():
    print(f"{'Shape':>20} {'k':>2} {'Pipeline(us)':>13} {'cuBLAS(us)':>11} {'Speedup':>8}")
    print("-" * 65)
 
-    for k in [2, 3, 4]:
+    for k in [2, 3, 4, 5]:
        print(f"\n--- k={k} ---")
        for M, K_dim, N, label in shapes:
            packed_tiled, absmax_tiled, codebook = prepare_kbit_weights(K_dim, N, k)
@@ -217,40 +223,36 @@ def bench_speedup_table():
            out = torch.zeros(M, N, dtype=torch.float16, device="cuda")
            A_copy = A.clone()
 
-            # Pipeline: rotate + GEMV
            def pipeline():
-                hadamard_rotate(A_copy, block_size=64)
+                hadamard_rotate(A_copy, block_size=ROTATION_BLOCK_SIZE)
                torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
                    A_copy, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out
                )
 
-            t_pipe = bench_graph(pipeline)
-
-            # cuBLAS baseline
-            t_cublas = bench_graph(lambda: torch.mm(A, W.t(), out=out))
+            t_pipe = bench(pipeline, inner, outer)
+            t_cublas = bench(lambda: torch.mm(A, W.t(), out=out), inner, outer)
 
            speedup = t_cublas / t_pipe
            shape_str = f"{M}x{K_dim}x{N}"
-            print(f"{shape_str:>20} {k:>2} {t_pipe:>13.2f} {t_cublas:>11.2f} {speedup:>7.2f}x {label}")
+            print(f"{shape_str:>20} {k:>2} {t_pipe:>13.3f} {t_cublas:>11.3f} {speedup:>7.2f}x {label}")
 
 
-def bench_cuda_graph_capture():
-    """Verify that all benchmarks above were graph-captured (implicit from bench_graph).
-    This just confirms the pipeline captures as a single graph explicitly."""
-    print("\n" + "=" * 70)
-    print("5. CUDA GRAPH CAPTURE VERIFICATION")
-    print("=" * 70)
-    print("All benchmarks above used CUDA graph capture + replay for timing.")
-    print("If they produced numbers, graph capture succeeded for all operations.")
+def main():
+    parser = argparse.ArgumentParser(description="Hadamard rotation + kbit M=1 pipeline benchmark")
+    parser.add_argument("--inner", type=int, default=500, help="Graph replays per measurement (default: 500)")
+    parser.add_argument("--outer", type=int, default=15, help="Measurements per benchmark (default: 15)")
+    args = parser.parse_args()
 
-
-if __name__ == "__main__":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA: {torch.version.cuda}")
+    print(f"Timing: batched graph replay ({args.inner} replays/measurement, median of {args.outer})")
    print()
 
-    bench_rotation_standalone()
-    bench_pipeline()
-    bench_cublas_baseline()
-    bench_speedup_table()
-    bench_cuda_graph_capture()
+    bench_rotation_standalone(args.inner, args.outer)
+    bench_pipeline(args.inner, args.outer)
+    bench_cublas_baseline(args.inner, args.outer)
+    bench_speedup_table(args.inner, args.outer)
+
+
+if __name__ == "__main__":
+    main()
```
