Skip to content

Commit d834367

Browse files
TimDettmers and claude committed
Add CUDA graph mode, stddev, and O shape to tiled vs flat benchmark
- Add --graph flag for CUDA graph replay timing (kernel-only, no dispatch overhead) - Add --trials flag with stddev reporting across multiple trials - Add missing O (4096x2048) shape to match full Qwen 72B shape set - Results: tiled layout 5-30% slower than flat on large shapes, neutral on KV Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b7e8407 commit d834367

File tree

1 file changed

+69
-28
lines changed

1 file changed

+69
-28
lines changed

benchmarks/bench_tiled_vs_flat.py

Lines changed: 69 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -23,21 +23,28 @@
2323

2424
parser = argparse.ArgumentParser()
2525
parser.add_argument("--ncu", action="store_true", help="NCU mode: single iteration, no timing")
26+
parser.add_argument("--graph", action="store_true", help="Use CUDA graph replay for accurate kernel timing")
2627
parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
27-
parser.add_argument("--iters", type=int, default=100, help="Timed iterations")
28+
parser.add_argument("--iters", type=int, default=100, help="Timed iterations per trial")
29+
parser.add_argument("--trials", type=int, default=5, help="Number of trials for stddev (graph mode)")
2830
args = parser.parse_args()
2931

3032
SHAPES = [
3133
("gateup", 2048, 5120),
3234
("down", 5120, 2048),
3335
("Q", 2048, 4096),
36+
("O", 4096, 2048),
3437
("KV", 2048, 512),
3538
]
3639
K_VALUES = [2, 3, 4, 5]
3740
M_VALUES = [1, 2, 4]
3841

39-
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}")
40-
print("-" * 60)
42+
if args.graph:
43+
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'±flat':>6} {'tiled_us':>8} {'±tiled':>6} {'diff%':>7}")
44+
print("-" * 76)
45+
else:
46+
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}")
47+
print("-" * 60)
4148

4249
for name, K_dim, N in SHAPES:
4350
for k in K_VALUES:
@@ -68,44 +75,78 @@
6875
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {'ncu':>8} {'ncu':>8} {'ncu':>7}")
6976
continue
7077

71-
# CUDA events timing
7278
start = torch.cuda.Event(enable_timing=True)
7379
end = torch.cuda.Event(enable_timing=True)
7480

75-
# --- Flat ---
76-
for _ in range(args.warmup):
81+
def call_flat():
7782
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
7883
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
7984
)
80-
torch.cuda.synchronize()
8185

82-
start.record()
83-
for _ in range(args.iters):
84-
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
85-
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
86-
)
87-
end.record()
88-
torch.cuda.synchronize()
89-
flat_us = start.elapsed_time(end) * 1000 / args.iters # ms -> us
90-
91-
# --- Tiled ---
92-
for _ in range(args.warmup):
86+
def call_tiled():
9387
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
9488
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
9589
)
96-
torch.cuda.synchronize()
9790

98-
start.record()
99-
for _ in range(args.iters):
100-
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
101-
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
102-
)
103-
end.record()
104-
torch.cuda.synchronize()
105-
tiled_us = start.elapsed_time(end) * 1000 / args.iters
91+
if args.graph:
92+
import statistics
93+
94+
# CUDA graph replay — measures kernel-only time
95+
for fn in (call_flat, call_tiled):
96+
for _ in range(3):
97+
fn()
98+
torch.cuda.synchronize()
99+
100+
def bench_graph(fn, trials, iters):
101+
s = torch.cuda.Stream()
102+
s.wait_stream(torch.cuda.current_stream())
103+
with torch.cuda.stream(s):
104+
g = torch.cuda.CUDAGraph()
105+
with torch.cuda.graph(g, stream=s):
106+
fn()
107+
torch.cuda.synchronize()
108+
times = []
109+
for _ in range(trials):
110+
start.record()
111+
for _ in range(iters):
112+
g.replay()
113+
end.record()
114+
torch.cuda.synchronize()
115+
times.append(start.elapsed_time(end) * 1000 / iters)
116+
return statistics.mean(times), statistics.stdev(times) if len(times) > 1 else 0.0
117+
118+
flat_us, flat_std = bench_graph(call_flat, args.trials, args.iters)
119+
tiled_us, tiled_std = bench_graph(call_tiled, args.trials, args.iters)
120+
else:
121+
# CUDA events timing (includes Python dispatch overhead)
122+
for _ in range(args.warmup):
123+
call_flat()
124+
torch.cuda.synchronize()
125+
start.record()
126+
for _ in range(args.iters):
127+
call_flat()
128+
end.record()
129+
torch.cuda.synchronize()
130+
flat_us = start.elapsed_time(end) * 1000 / args.iters
131+
132+
for _ in range(args.warmup):
133+
call_tiled()
134+
torch.cuda.synchronize()
135+
start.record()
136+
for _ in range(args.iters):
137+
call_tiled()
138+
end.record()
139+
torch.cuda.synchronize()
140+
tiled_us = start.elapsed_time(end) * 1000 / args.iters
106141

107142
diff_pct = (tiled_us - flat_us) / flat_us * 100
108-
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%")
143+
if args.graph:
144+
print(
145+
f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2}"
146+
f" {flat_us:>8.1f} {flat_std:>5.1f}σ {tiled_us:>8.1f} {tiled_std:>5.1f}σ {diff_pct:>+7.1f}%"
147+
)
148+
else:
149+
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%")
109150

110151
# Correctness check (once per shape/k)
111152
assert torch.equal(out_flat, out_tiled) or torch.allclose(out_flat, out_tiled, rtol=0.05, atol=0.1), (

0 commit comments

Comments (0)