Skip to content

Commit b7e8407

Browse files
TimDettmers and claude
committed
Pass cudaStream_t through all kbit kernel launchers for CUDA graph support
All kbit kernels (quantize, repack, MMA GEMM, grouped GEMM, scalar GEMV) previously launched on CUDA stream 0, preventing CUDA graph capture. Now every launcher accepts a cudaStream_t parameter passed from Python via _get_tensor_stream(), matching the pattern used by legacy bitsandbytes kernels. This enables CUDA graph replay benchmarking and is required for any downstream CUDA graph integration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9d11e85 commit b7e8407

File tree

6 files changed

+486
-123
lines changed

6 files changed

+486
-123
lines changed

benchmarks/bench_cuda_events.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
"""CUDA event benchmark for kbit kernels — measures kernel-only latency.
2+
3+
Uses pre-allocated output buffers (out parameter) and CUDA events to
4+
measure just the kernel execution time, excluding allocation overhead.
5+
6+
Output: same shape/k/M grid as bench_ncu.sh for direct comparison.
7+
8+
Usage:
9+
python benchmarks/bench_cuda_events.py # all kernels
10+
python benchmarks/bench_cuda_events.py --kernel mma # MMA only
11+
python benchmarks/bench_cuda_events.py --kernel scalar # scalar GEMV only
12+
"""
13+
14+
import argparse
15+
import os
16+
import sys
17+
18+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19+
20+
import torch
21+
from bitsandbytes.functional import create_normal_float_codebook
22+
23+
# Benchmark configuration: warmup runs (eager and graph replay) and timed iterations.
WARMUP = 20
ITERS = 100

# Same shapes as ncu_driver.py
# Each entry is (name, K_dim, N) for a dense weight matrix.
dense_shapes = [
    ("gateup", 2048, 5120),
    ("down", 5120, 2048),
    ("Q", 2048, 4096),
    ("O", 4096, 2048),
    ("KV", 2048, 512),
]

# Per-expert (name, K_dim, N) shapes exercised by the grouped (MoE) kernel.
moe_shapes = [
    ("moe_gu", 2048, 512),
    ("moe_dn", 512, 2048),
]

# Quantization bit-widths to sweep.
k_bits_list = [2, 3, 4, 5]
# Number of experts for the grouped GEMM benchmark.
NUM_EXPERTS = 8
43+
44+
def bench_kernel(fn, warmup=WARMUP, iters=ITERS):
    """Measure kernel-only latency via CUDA graph replay and CUDA events.

    The callable *fn* is run eagerly for *warmup* iterations, captured once
    into a CUDA graph, and then timed over *iters* graph replays — so the
    measurement excludes Python dispatch and kernel-launch overhead.

    Returns the average time per replay in microseconds.
    """
    # Eager warmup so any lazy initialization happens outside graph capture.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    # Record a single invocation into a replayable CUDA graph.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph):
        fn()
    torch.cuda.synchronize()

    t_begin = torch.cuda.Event(enable_timing=True)
    t_finish = torch.cuda.Event(enable_timing=True)

    # Warm up the replay path as well before timing it.
    for _ in range(warmup):
        cuda_graph.replay()
    torch.cuda.synchronize()

    t_begin.record()
    for _ in range(iters):
        cuda_graph.replay()
    t_finish.record()
    torch.cuda.synchronize()

    elapsed_ms = t_begin.elapsed_time(t_finish)
    return (elapsed_ms / iters) * 1000.0  # ms per replay -> microseconds
78+
79+
80+
def prepare_dense_data(device):
    """Quantize and repack every dense shape for every bit-width.

    Returns a dict keyed by (shape_name, k) holding
    (K_dim, N, packed_flat, absmax_flat, packed_tiled, absmax_tiled, codebook).
    """
    prepared = {}
    for shape_name, k_dim, n_cols in dense_shapes:
        for bits in k_bits_list:
            codebook = create_normal_float_codebook(bits, device=device)
            weights = torch.randn(k_dim * n_cols, device=device, dtype=torch.float32)
            packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(weights, codebook, bits)
            packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
                packed_flat, absmax_flat, k_dim, n_cols, bits
            )
            prepared[(shape_name, bits)] = (
                k_dim,
                n_cols,
                packed_flat,
                absmax_flat,
                packed_tiled,
                absmax_tiled,
                codebook,
            )
    return prepared
93+
94+
95+
def prepare_moe_data(device):
    """Quantize per-expert weights for each MoE shape and bit-width.

    Returns a dict keyed by (shape_name, k) holding
    (K_dim, N, B_packed_all, B_absmax_all, codebook), where the packed/absmax
    tensors concatenate all NUM_EXPERTS experts along dim 0 (tiled layout).
    """
    prepared = {}
    for shape_name, k_dim, n_cols in moe_shapes:
        for bits in k_bits_list:
            codebook = create_normal_float_codebook(bits, device=device)
            per_expert_packed = []
            per_expert_absmax = []
            for _ in range(NUM_EXPERTS):
                weights = torch.randn(k_dim * n_cols, device=device, dtype=torch.float32)
                packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(weights, codebook, bits)
                packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
                    packed_flat, absmax_flat, k_dim, n_cols, bits
                )
                per_expert_packed.append(packed_tiled)
                per_expert_absmax.append(absmax_tiled)
            prepared[(shape_name, bits)] = (
                k_dim,
                n_cols,
                torch.cat(per_expert_packed, dim=0),
                torch.cat(per_expert_absmax, dim=0),
                codebook,
            )
    return prepared
112+
113+
114+
def bench_mma(data, m_vals, device):
    """Benchmark the MMA GEMM kernel with pre-allocated out/workspace buffers.

    Iterates every dense shape x bit-width x M and prints the average
    per-launch time in microseconds as measured by bench_kernel.
    """
    print("\n=== MMA kernel (CUDA events) ===")
    print(f"{'shape':<8} {'k':>2} {'M':>2} {'avg_us':>10}")
    print("---")

    # Only the shape name is taken from dense_shapes; K_dim/N come from the
    # prepared table (the original re-unpacked them, shadowing the loop vars).
    for name, _, _ in dense_shapes:
        for k in k_bits_list:
            K_dim, N, _, _, packed_tiled, absmax_tiled, codebook = data[(name, k)]
            for M in m_vals:
                A = torch.randn(M, K_dim, dtype=torch.float16, device=device)
                out = torch.empty(M, N, dtype=torch.float16, device=device)

                # Workspace and tile counters for the in-place (_) variant.
                C_workspace = torch.zeros(M, N, dtype=torch.float32, device=device)
                # Upper bound on tile count
                TILE_M = 16 * max(1, min(4, (M + 15) // 16))
                TILE_N = 64 if M <= 16 and N % 64 == 0 else 128
                m_tiles = (M + TILE_M - 1) // TILE_M
                n_tiles = N // TILE_N
                tile_counters = torch.zeros(m_tiles * n_tiles, dtype=torch.int32, device=device)
                # NOTE(review): tile_counters is zeroed once here but the graph
                # is replayed many times — confirm the kernel resets it itself.

                # def instead of an assigned lambda (PEP 8 / ruff E731); it is
                # invoked immediately by bench_kernel, so capturing the loop
                # variables by closure is safe.
                def fn():
                    torch.ops.bitsandbytes.kbit_gemm_prod_(
                        A, packed_tiled, absmax_tiled, codebook,
                        K_dim, N, k, 1, out, C_workspace, tile_counters,
                    )

                avg_us = bench_kernel(fn)
                print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")
142+
143+
144+
def bench_scalar(data, m_vals, device):
    """Benchmark the scalar GEMV kernel (tiled layout) with a pre-allocated out.

    Only M <= 4 is exercised — the scalar GEMV path targets tiny batches.
    """
    m_vals = [m for m in m_vals if m <= 4]
    if not m_vals:
        print("\n=== Scalar GEMV (CUDA events) ===\n(no M<=4 values)")
        return

    print(f"\n=== Scalar GEMV M<={max(m_vals)} (CUDA events) ===")
    print(f"{'shape':<8} {'k':>2} {'M':>2} {'avg_us':>10}")
    print("---")

    # Only the shape name is taken from dense_shapes; K_dim/N come from the
    # prepared table (the original re-unpacked them, shadowing the loop vars).
    for name, _, _ in dense_shapes:
        for k in k_bits_list:
            K_dim, N, _, _, packed_tiled, absmax_tiled, codebook = data[(name, k)]
            for M in m_vals:
                A = torch.randn(M, K_dim, dtype=torch.float16, device=device)
                out = torch.empty(M, N, dtype=torch.float16, device=device)

                # def instead of an assigned lambda (PEP 8 / ruff E731); it is
                # invoked immediately by bench_kernel, so the closure is safe.
                def fn():
                    torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
                        A, packed_tiled, absmax_tiled, codebook,
                        K_dim, N, k, out,
                    )

                avg_us = bench_kernel(fn)
                print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")
168+
169+
170+
def bench_grouped(moe_data, m_vals, device):
    """Benchmark the grouped MMA kernel (MoE), M tokens routed to each expert."""
    print(f"\n=== Grouped MMA ({NUM_EXPERTS} experts, CUDA events) ===")
    print(f"{'shape':<8} {'k':>2} {'M':>2} {'avg_us':>10}")
    print("---")

    # Only the shape name is taken from moe_shapes; K_dim/N come from the
    # prepared table (the original re-unpacked them, shadowing the loop vars).
    for name, _, _ in moe_shapes:
        for k in k_bits_list:
            K_dim, N, B_packed_all, B_absmax_all, codebook = moe_data[(name, k)]
            for M in m_vals:
                total_tokens = M * NUM_EXPERTS
                A_concat = torch.randn(total_tokens, K_dim, dtype=torch.float16, device=device)
                # Uniform routing: expert i owns tokens [i*M, (i+1)*M).
                offsets = list(range(0, total_tokens + 1, M))
                expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=device)

                # Grouped GEMM doesn't have an _ variant yet — use the allocating version.
                # def instead of an assigned lambda (PEP 8 / ruff E731); it is
                # invoked immediately by bench_kernel, so the closure is safe.
                def fn():
                    torch.ops.bitsandbytes.kbit_grouped_gemm(
                        A_concat, B_packed_all, B_absmax_all, codebook,
                        expert_offsets, K_dim, N, k, NUM_EXPERTS, M,
                    )

                avg_us = bench_kernel(fn)
                print(f"{name:<8} {k:>2} {M:>2} {avg_us:>10.2f}")
192+
193+
194+
def main():
    """Parse CLI arguments and dispatch the selected benchmark suites."""
    parser = argparse.ArgumentParser(description="CUDA event kernel benchmark")
    parser.add_argument("--kernel", choices=["mma", "scalar", "grouped", "all"], default="all")
    parser.add_argument("--m-vals", default="1,2,3,4,5,6,7,8", help="Comma-separated M values")
    args = parser.parse_args()

    m_vals = [int(x) for x in args.m_vals.split(",")]
    device = torch.device("cuda")

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Warmup: {WARMUP}, Iterations: {ITERS}")
    print(f"M values: {m_vals}")

    # Dense data is shared by the mma and scalar suites; MoE data is only
    # prepared when the grouped suite actually runs.
    dense_data = prepare_dense_data(device)

    run_all = args.kernel == "all"
    if run_all or args.kernel == "mma":
        bench_mma(dense_data, m_vals, device)

    if run_all or args.kernel == "scalar":
        bench_scalar(dense_data, m_vals, device)

    if run_all or args.kernel == "grouped":
        moe_data = prepare_moe_data(device)
        bench_grouped(moe_data, m_vals, device)


if __name__ == "__main__":
    main()

benchmarks/bench_tiled_vs_flat.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Benchmark tiled vs flat scalar GEMV with pre-allocated output buffers.
2+
3+
Measures kernel-only time by pre-allocating all buffers before the timing loop.
4+
No allocations inside the measured region — fair comparison between flat and tiled.
5+
6+
Usage:
7+
python benchmarks/bench_tiled_vs_flat.py
8+
python benchmarks/bench_tiled_vs_flat.py --ncu # NCU mode (single iteration)
9+
"""
10+
11+
import argparse
12+
import os
13+
import sys
14+
15+
for p in [".", ".."]:
16+
if os.path.isfile(os.path.join(p, "bitsandbytes", "__init__.py")):
17+
sys.path.insert(0, os.path.abspath(p))
18+
break
19+
20+
import torch
21+
22+
from bitsandbytes.functional import create_normal_float_codebook
23+
24+
parser = argparse.ArgumentParser()
25+
parser.add_argument("--ncu", action="store_true", help="NCU mode: single iteration, no timing")
26+
parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations")
27+
parser.add_argument("--iters", type=int, default=100, help="Timed iterations")
28+
args = parser.parse_args()
29+
30+
SHAPES = [
31+
("gateup", 2048, 5120),
32+
("down", 5120, 2048),
33+
("Q", 2048, 4096),
34+
("KV", 2048, 512),
35+
]
36+
K_VALUES = [2, 3, 4, 5]
37+
M_VALUES = [1, 2, 4]
38+
39+
print(f"{'shape':<8} {'K_dim':>5} {'N':>5} {'k':>2} {'M':>2} {'flat_us':>8} {'tiled_us':>8} {'diff%':>7}")
40+
print("-" * 60)
41+
42+
for name, K_dim, N in SHAPES:
43+
for k in K_VALUES:
44+
codebook = create_normal_float_codebook(k).cuda()
45+
W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
46+
47+
# Quantize and repack
48+
packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
49+
packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
50+
packed_flat, absmax_flat, K_dim, N, k
51+
)
52+
53+
for M in M_VALUES:
54+
A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
55+
56+
# Pre-allocate output buffers
57+
out_flat = torch.empty(M, N, dtype=torch.float16, device="cuda")
58+
out_tiled = torch.empty(M, N, dtype=torch.float16, device="cuda")
59+
60+
if args.ncu:
61+
# NCU mode: single call each, profiler captures kernel time
62+
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
63+
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
64+
)
65+
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
66+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
67+
)
68+
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {'ncu':>8} {'ncu':>8} {'ncu':>7}")
69+
continue
70+
71+
# CUDA events timing
72+
start = torch.cuda.Event(enable_timing=True)
73+
end = torch.cuda.Event(enable_timing=True)
74+
75+
# --- Flat ---
76+
for _ in range(args.warmup):
77+
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
78+
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
79+
)
80+
torch.cuda.synchronize()
81+
82+
start.record()
83+
for _ in range(args.iters):
84+
torch.ops.bitsandbytes.kbit_scalar_gemv.out(
85+
A, packed_flat, absmax_flat, codebook, K_dim, N, k, out_flat
86+
)
87+
end.record()
88+
torch.cuda.synchronize()
89+
flat_us = start.elapsed_time(end) * 1000 / args.iters # ms -> us
90+
91+
# --- Tiled ---
92+
for _ in range(args.warmup):
93+
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
94+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
95+
)
96+
torch.cuda.synchronize()
97+
98+
start.record()
99+
for _ in range(args.iters):
100+
torch.ops.bitsandbytes.kbit_scalar_gemv_tiled_(
101+
A, packed_tiled, absmax_tiled, codebook, K_dim, N, k, out_tiled
102+
)
103+
end.record()
104+
torch.cuda.synchronize()
105+
tiled_us = start.elapsed_time(end) * 1000 / args.iters
106+
107+
diff_pct = (tiled_us - flat_us) / flat_us * 100
108+
print(f"{name:<8} {K_dim:>5} {N:>5} {k:>2} {M:>2} {flat_us:>8.1f} {tiled_us:>8.1f} {diff_pct:>+7.1f}%")
109+
110+
# Correctness check (once per shape/k)
111+
assert torch.equal(out_flat, out_tiled) or torch.allclose(out_flat, out_tiled, rtol=0.05, atol=0.1), (
112+
f"MISMATCH {name} k={k}: max diff = {(out_flat - out_tiled).abs().max().item()}"
113+
)

bitsandbytes/backends/cuda/ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ def _(A: torch.Tensor, codebook: torch.Tensor, k: int) -> tuple[torch.Tensor, to
799799
get_ptr(absmax),
800800
get_ptr(packed),
801801
ct.c_int(n),
802+
_get_tensor_stream(A),
802803
)
803804

804805
return packed, absmax
@@ -993,6 +994,7 @@ def _(
993994
get_ptr(absmax_tiled),
994995
ct.c_int(K_dim),
995996
ct.c_int(N),
997+
_get_tensor_stream(packed_flat),
996998
)
997999

9981000
return packed_tiled, absmax_tiled
@@ -1036,6 +1038,7 @@ def _kbit_gemm_prod_impl(A, B_packed, B_absmax, codebook, K_dim, N, k, k_chunks,
10361038
ct.c_int(K_dim),
10371039
ct.c_int(N),
10381040
ct.c_int(k_chunks),
1041+
_get_tensor_stream(A),
10391042
)
10401043

10411044

@@ -1143,6 +1146,7 @@ def _kbit_grouped_gemm_impl(
11431146
ct.c_int(N),
11441147
ct.c_int(num_experts),
11451148
ct.c_int(max_M),
1149+
_get_tensor_stream(A_concat),
11461150
)
11471151

11481152

@@ -1259,6 +1263,7 @@ def _kbit_scalar_gemv_impl(
12591263
ct.c_int(M),
12601264
ct.c_int(K_dim),
12611265
ct.c_int(N),
1266+
_get_tensor_stream(A),
12621267
)
12631268

12641269

@@ -1330,6 +1335,7 @@ def _(
13301335
ct.c_int(M),
13311336
ct.c_int(K_dim),
13321337
ct.c_int(N),
1338+
_get_tensor_stream(A),
13331339
)
13341340
return out
13351341

@@ -1366,5 +1372,6 @@ def _(
13661372
ct.c_int(M),
13671373
ct.c_int(K_dim),
13681374
ct.c_int(N),
1375+
_get_tensor_stream(A),
13691376
)
13701377
return out

0 commit comments

Comments
 (0)