Skip to content

Commit 27cf6a2

Browse files
TimDettmers and claude committed
Add kbit GEMM benchmark script
Benchmarks the production kernel against a cuBLAS fp16/bf16 baseline across LLM-typical shapes. Measures TFLOPS, effective GB/s, and the speedup ratio.

Initial results on RTX 4090 with K=4, TILE_M=16:
- 1.56x faster than cuBLAS for (1, 4096, 11008) — memory-bound regime
- cuBLAS faster for square/compute-bound cases — expected, since the current tile is small (TILE_M=16) and only uses 2 N-blocks per warp

Next optimization targets: multi-M-block tiling, larger TILE_N, and better C output coalescing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b64bb91 commit 27cf6a2

File tree

1 file changed

+192
-0
lines changed

1 file changed

+192
-0
lines changed

benchmarks/bench_kbit_gemm.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Benchmark for kbit GEMM kernel.
2+
3+
Measures throughput (TFLOPS) and effective memory bandwidth (GB/s) for:
4+
1. kbit_gemm_prod (production kernel, fp16 and bf16)
5+
2. cuBLAS fp16 GEMM (baseline)
6+
3. Standalone dequant + cuBLAS (simulated fused baseline)
7+
"""
8+
9+
import argparse
10+
import sys
11+
import time
12+
13+
import torch
14+
15+
# Ensure bitsandbytes is importable from the worktree
16+
sys.path.insert(0, ".")
17+
import bitsandbytes # noqa: E402
18+
from bitsandbytes import _ops # noqa: E402, F401
19+
from scipy.stats import norm # noqa: E402
20+
21+
BLOCKSIZE = 32  # Elements per quantization block (one bit-plane fits one 32-bit word).


def create_normal_float_codebook(k: int) -> torch.Tensor:
    """Build a 2**k-level NF-style codebook from equally spaced Gaussian quantiles.

    The levels are the standard-normal inverse CDF evaluated at the midpoints
    of 2**k equal-probability bins, rescaled so the largest magnitude is 1.0.
    """
    num_levels = 2 ** k
    midpoints = torch.linspace(0.5 / num_levels, 1.0 - 0.5 / num_levels, num_levels)
    levels = torch.from_numpy(norm.ppf(midpoints.numpy())).to(torch.float32)
    # Normalize into [-1, 1] so per-block absmax scaling recovers the range.
    return levels / levels.abs().max()


def quantize_kbit_ref(A, codebook, blocksize=BLOCKSIZE):
    """Reference blockwise k-bit quantizer (slow Python path, for validation).

    Flattens *A*, zero-pads to a multiple of *blocksize*, scales each block by
    its absolute maximum, and maps every element to its nearest codebook entry.

    Returns:
        (indices, absmax): uint8 codebook indices for the first A.numel()
        elements, and the per-block float32 absolute maxima (padding included).
    """
    flat = A.float().reshape(-1)
    count = flat.numel()
    tail = count % blocksize
    if tail:
        flat = torch.nn.functional.pad(flat, (0, blocksize - tail))
    grouped = flat.reshape(-1, blocksize)
    absmax = grouped.abs().max(dim=1).values
    # Clamp so all-zero blocks do not divide by zero.
    scaled = grouped / absmax.clamp(min=1e-8).unsqueeze(1)
    # Brute-force nearest neighbor: |scaled[b, i] - codebook[j]| over all j.
    dist = (scaled.unsqueeze(2) - codebook.float().reshape(1, 1, -1)).abs()
    nearest = dist.argmin(dim=2).to(torch.uint8)
    return nearest.reshape(-1)[:count], absmax


def pack_kbit_ref(indices, k, blocksize=BLOCKSIZE):
    """Reference bit-plane packer (slow Python path, for validation).

    Each block of *blocksize* indices yields k int32 words; word *plane* holds
    bit *plane* of every index in the block (element i lands in bit i of the
    word). Words with bit 31 set are wrapped negative so they fit torch.int32.
    Assumes blocksize <= 32 so one bit-plane fits a single word.
    """
    total = indices.numel()
    tail = total % blocksize
    if tail:
        indices = torch.nn.functional.pad(indices.int(), (0, blocksize - tail))
    rows = indices.int().reshape(-1, blocksize).tolist()
    words = []
    for row in rows:
        for plane in range(k):
            word = 0
            for pos, value in enumerate(row):
                word |= ((value >> plane) & 1) << pos
            # Reinterpret as signed 32-bit so torch.int32 accepts the value.
            if word & (1 << 31):
                word -= 1 << 32
            words.append(word)
    return torch.tensor(words, dtype=torch.int32)
70+
71+
72+
def prepare_weights(K_dim, N, k):
    """Quantize random weights with the CUDA kernels.

    Returns (packed_tiled, absmax_tiled, codebook, W): the tiled packed data
    and absmax for the kbit kernel, the codebook on the GPU, and the original
    fp16 weight matrix W of shape (N, K_dim) kept for the cuBLAS baseline.
    """
    codebook = create_normal_float_codebook(k).cuda()
    W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
    # Quantize on-device; the Python reference path is far too slow here.
    packed_flat, absmax = torch.ops.bitsandbytes.quantize_kbit(W.reshape(-1), codebook, k)
    # Re-layout into the tiled format the GEMM kernel consumes.
    packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(packed_flat, absmax.cuda(), K_dim, N, k)
    return packed_tiled, absmax_tiled, codebook, W
82+
83+
84+
def bench_kbit_gemm(M, K_dim, N, k, k_chunks, dtype, packed_tiled, absmax_tiled, codebook,
                    warmup=10, iters=100):
    """Time the production kbit GEMM kernel; returns mean seconds per call."""
    activations = torch.randn(M, K_dim, dtype=dtype, device="cuda")

    def run_once():
        torch.ops.bitsandbytes.kbit_gemm_prod(
            activations, packed_tiled, absmax_tiled, codebook, K_dim, N, k, k_chunks
        )

    # Warm up to amortize launch/caching overheads before the timed section.
    for _ in range(warmup):
        run_once()
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        run_once()
    # Synchronize so the wall-clock span covers all queued kernels.
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
101+
102+
103+
def bench_cublas(M, K_dim, N, dtype, W_fp16, warmup=10, iters=100):
    """Time a dense cuBLAS GEMM (via torch.mm) as the baseline; returns mean seconds per call."""
    activations = torch.randn(M, K_dim, dtype=dtype, device="cuda")
    weight = W_fp16.to(dtype).cuda()

    # Warm up before the timed section.
    for _ in range(warmup):
        torch.mm(activations, weight.T)
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        torch.mm(activations, weight.T)
    # Synchronize so the wall-clock span covers all queued kernels.
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters
120+
121+
122+
def main():
    """Run the benchmark sweep over LLM-typical GEMM shapes and print a results table."""
    ap = argparse.ArgumentParser(description="Benchmark kbit GEMM kernel")
    ap.add_argument("--k", type=int, default=4, help="Bit width (2-5)")
    ap.add_argument("--dtype", choices=["fp16", "bf16"], default="fp16")
    ap.add_argument("--warmup", type=int, default=20)
    ap.add_argument("--iters", type=int, default=200)
    ap.add_argument("--k-chunks", type=int, default=1, help="Split-K chunks")
    args = ap.parse_args()

    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
    bits = args.k

    # LLM-typical (M, K_dim, N) shapes: decode (M=1) through small-batch sizes.
    shapes = [
        (1, 4096, 4096),
        (1, 4096, 11008),
        (4, 4096, 4096),
        (4, 4096, 11008),
        (8, 4096, 4096),
        (16, 4096, 4096),
        (32, 4096, 4096),
        (64, 4096, 4096),
        (128, 4096, 4096),
    ]

    print(f"kbit GEMM Benchmark: K={bits}, dtype={args.dtype}, k_chunks={args.k_chunks}")
    print(f"Warmup={args.warmup}, Iters={args.iters}")
    print()
    print(f"{'M':>5} {'K_dim':>6} {'N':>6} | {'kbit (us)':>10} {'kbit TFLOPS':>12} {'kbit GB/s':>10} | "
          f"{'cuBLAS (us)':>12} {'cuBLAS TFLOPS':>14} | {'Speedup':>8}")
    print("-" * 115)

    for M, K_dim, N in shapes:
        # Round N up to a multiple of 128 before quantizing/benchmarking.
        N_padded = ((N + 127) // 128) * 128

        packed_tiled, absmax_tiled, codebook, W = prepare_weights(K_dim, N_padded, bits)

        t_kbit = bench_kbit_gemm(M, K_dim, N_padded, bits, args.k_chunks, torch_dtype,
                                 packed_tiled, absmax_tiled, codebook,
                                 warmup=args.warmup, iters=args.iters)
        t_cublas = bench_cublas(M, K_dim, N_padded, torch_dtype, W.half(),
                                warmup=args.warmup, iters=args.iters)

        # 2*M*K*N flops per GEMM (multiply + add).
        flops = 2 * M * K_dim * N_padded
        tflops_kbit = flops / t_kbit / 1e12
        tflops_cublas = flops / t_cublas / 1e12

        # Effective traffic for the kbit kernel: fp16 A and C plus compressed B
        # (k bits/element) and a per-32-element-block absmax term.
        # NOTE(review): the absmax term counts 1 byte per block — confirm this
        # matches the kernel's actual absmax storage format.
        a_bytes = M * K_dim * 2
        b_bytes = N_padded * K_dim * bits / 8 + N_padded * (K_dim // 32)
        c_bytes = M * N_padded * 2
        gbps_kbit = (a_bytes + b_bytes + c_bytes) / t_kbit / 1e9

        speedup = t_cublas / t_kbit

        print(f"{M:5d} {K_dim:6d} {N_padded:6d} | {t_kbit*1e6:10.1f} {tflops_kbit:12.3f} {gbps_kbit:10.1f} | "
              f"{t_cublas*1e6:12.1f} {tflops_cublas:14.3f} | {speedup:8.2f}x")

    print()
189+
190+
191+
# Script entry point: run the benchmark sweep when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)