Skip to content

Commit daa2f12

Browse files
TimDettmers and claude
committed
Add MoE analysis benchmarks, update grouped GEMM baseline to bmm
- bench_grouped_gemm.py: replace sequential cuBLAS baseline with torch.bmm (batched GEMM) for fair single-launch comparison - bench_moe_e2e.py: end-to-end MoE layer timing with realistic expert routing distributions for Qwen3 and GLM-4.7 - bench_gemv_analysis.py: dequant+bmm vs theoretical scalar GEMV - bench_gemv_theoretical.py: roofline model for scalar kbit kernel showing 2-5x theoretical speedup over bmm at all batch sizes - progress.md: consolidate into complete self-contained dev record - Remove optimization.md (superseded by progress.md sections 22-24) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8e527ff commit daa2f12

File tree

6 files changed

+1867
-2186
lines changed

6 files changed

+1867
-2186
lines changed

benchmarks/bench_gemv_analysis.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""Analysis: small-batch strategies for kbit MoE GEMM.
2+
3+
Benchmarks three approaches for the batch=1 to batch=8 regime:
4+
1. kbit grouped GEMM (current kernel)
5+
2. cuBLAS bmm (fp16 baseline)
6+
3. Dequant-to-fp16 + cuBLAS bmm (hybrid approach)
7+
8+
Also estimates theoretical performance of a specialized kbit GEMV kernel.
9+
"""
10+
11+
import sys
12+
import time
13+
14+
import torch
15+
16+
sys.path.insert(0, ".")
17+
import bitsandbytes # noqa: E402
18+
from bitsandbytes import _ops # noqa: E402, F401
19+
from scipy.stats import norm # noqa: E402
20+
21+
22+
def create_normal_float_codebook(k: int) -> torch.Tensor:
    """Build a normal-float codebook with ``2**k`` levels.

    The levels are standard-normal quantiles taken at evenly spaced
    probabilities (mid-points of ``2**k`` equal probability bins), then
    rescaled so the largest magnitude is exactly 1.0.
    """
    num_levels = 2 ** k
    # Mid-point probabilities: 0.5/n, 1.5/n, ..., (n-0.5)/n.
    probs = torch.linspace(0.5 / num_levels, 1.0 - 0.5 / num_levels, num_levels)
    levels = torch.tensor(norm.ppf(probs.numpy()), dtype=torch.float32)
    # Normalize to the [-1, 1] range expected by the kbit kernels.
    return levels / levels.abs().max()
28+
29+
30+
def bench(fn, warmup=30, iters=500):
    """Return average wall-clock seconds per call of ``fn``.

    Performs ``warmup`` untimed calls first, then times ``iters`` calls.
    The CUDA device is synchronized before starting and after stopping the
    clock so asynchronous kernel launches are fully accounted for.
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    t_start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    t_stop = time.perf_counter()
    return (t_stop - t_start) / iters
39+
40+
41+
def main():
    """Benchmark small-batch kbit MoE GEMM strategies and print a report.

    Measures three approaches per layer shape and expert count:
      1. the kbit grouped GEMM kernel,
      2. a cuBLAS fp16 ``torch.bmm`` baseline,
      3. dequantize-to-fp16 followed by ``torch.bmm``,
    then prints a roofline-style estimate for a specialized kbit GEMV.
    Requires a CUDA device and the bitsandbytes kbit custom ops.
    """
    k = 4
    codebook = create_normal_float_codebook(k).cuda()

    # Qwen3-Coder-Next MoE shapes as (K, N, label).
    shapes = [
        (2048, 512, "gate/up"),
        (512, 2048, "down"),
    ]

    print(f"Small-Batch MoE Strategy Analysis (K={k}, RTX 4090)")
    print("Model: Qwen3-Coder-Next (512 experts, top-8)")
    print()

    for K_dim, N, layer_name in shapes:
        # The kernel requires N padded up to a multiple of 128.
        N_padded = ((N + 127) // 128) * 128
        print(f"{'='*90}")
        print(f" Layer: {layer_name} ({K_dim} x {N_padded})")
        print(f"{'='*90}")
        print()

        hdr = (f"{'#exp':>4} {'M':>2} | {'kbit grp':>8} {'bmm fp16':>8} "
               f"{'dq+bmm':>8} | {'grp/bmm':>8} {'dq+bmm/bmm':>11}")
        print(hdr)
        print("-" * len(hdr))

        for num_experts in [1, 4, 8, 16, 32, 64]:
            # One token routed to each expert (the batch=1 regime).
            M_per_expert = 1

            # --- Prepare kbit weights ---
            packed_list = []
            absmax_list = []
            # Keep flat packed + absmax for the dequant path.
            flat_packed_list = []
            flat_absmax_list = []
            W_list = []

            for _ in range(num_experts):
                W = torch.randn(N_padded, K_dim, dtype=torch.float16, device="cuda")
                packed_flat, absmax_flat = torch.ops.bitsandbytes.quantize_kbit(
                    W.reshape(-1), codebook, k
                )
                packed_tiled, absmax_tiled = torch.ops.bitsandbytes.repack_kbit(
                    packed_flat, absmax_flat.cuda(), K_dim, N_padded, k
                )
                packed_list.append(packed_tiled)
                absmax_list.append(absmax_tiled)
                flat_packed_list.append(packed_flat)
                flat_absmax_list.append(absmax_flat.cuda())
                W_list.append(W)

            B_packed_all = torch.cat(packed_list, dim=0)
            B_absmax_all = torch.cat(absmax_list, dim=0)

            # --- Build activations ---
            A_list = []
            offsets = [0]
            for _ in range(num_experts):
                A_i = torch.randn(M_per_expert, K_dim, dtype=torch.float16, device="cuda")
                A_list.append(A_i)
                offsets.append(offsets[-1] + M_per_expert)

            A_concat = torch.cat(A_list, dim=0)
            expert_offsets = torch.tensor(offsets, dtype=torch.int32, device="cuda")

            # --- 1. kbit grouped GEMM ---
            t_grouped = bench(lambda: torch.ops.bitsandbytes.kbit_grouped_gemm(
                A_concat, B_packed_all, B_absmax_all, codebook,
                expert_offsets, K_dim, N_padded, k, num_experts,
            ))

            # --- 2. cuBLAS bmm (fp16 baseline) ---
            A_batched = torch.stack(A_list, dim=0)
            W_batched_T = torch.stack([W.T for W in W_list], dim=0)

            t_bmm = bench(lambda: torch.bmm(A_batched, W_batched_T))

            # --- 3. Dequant + bmm ---
            n_elements = N_padded * K_dim

            def dequant_then_bmm():
                # Dequant each expert's weights to fp16, then one bmm.
                deq_list = []
                for i in range(num_experts):
                    deq = torch.ops.bitsandbytes.dequantize_kbit(
                        flat_packed_list[i], codebook, flat_absmax_list[i],
                        k, n_elements, torch.float16,
                    )
                    deq_list.append(deq.view(N_padded, K_dim).T)
                W_batch = torch.stack(deq_list, dim=0)
                return torch.bmm(A_batched, W_batch)

            t_dq_bmm = bench(dequant_then_bmm)

            # Also time just the dequant part to isolate its cost.
            def just_dequant():
                for i in range(num_experts):
                    torch.ops.bitsandbytes.dequantize_kbit(
                        flat_packed_list[i], codebook, flat_absmax_list[i],
                        k, n_elements, torch.float16,
                    )

            t_dq_only = bench(just_dequant)

            ratio_grp = t_grouped / t_bmm
            ratio_dq = t_dq_bmm / t_bmm

            print(f"{num_experts:4d} {M_per_expert:2d} | {t_grouped*1e6:7.0f}us "
                  f"{t_bmm*1e6:7.0f}us {t_dq_bmm*1e6:7.0f}us | "
                  f"{ratio_grp:7.2f}x {ratio_dq:10.2f}x"
                  f" (dq alone: {t_dq_only*1e6:.0f}us)")

        print()

    # Theoretical GEMV analysis
    print(f"\n{'='*90}")
    print(" Theoretical: specialized kbit GEMV for batch=1")
    print(f"{'='*90}")
    print()
    print(" For M=1 (one token per expert), the GEMM kernel wastes 93.75% of tensor")
    print(" core work (TILE_M=16 but only 1 row has data). A scalar GEMV avoids this.")
    print()

    for K_dim, N, name in shapes:
        N_padded = ((N + 127) // 128) * 128

        # RTX 4090 specs (approximate).
        l2_bw = 2000  # GB/s effective L2 bandwidth
        sms = 128
        cores_per_sm = 128
        clock_ghz = 2.52

        # Model the top-8 routing case: 8 active experts, M=1 each.
        ne = 8
        # kbit payload = packed weights (k bits/element) + absmax scales
        # (one per 32-element group).
        kbit_data = ne * (N_padded * K_dim * k / 8 + N_padded * (K_dim // 32))
        fp16_data = ne * N_padded * K_dim * 2

        # Bandwidth time assuming the 8 experts stay L2-resident.
        t_bw_kbit = kbit_data / (l2_bw * 1e9) * 1e6  # us
        t_bw_fp16 = fp16_data / (l2_bw * 1e9) * 1e6

        # Instruction time for a scalar kbit GEMV:
        # ~14 integer/fp ops per element for dequant + FMA.
        total_elements = ne * N_padded * K_dim
        ops_per_element = 14
        total_ops = total_elements * ops_per_element
        # Scalar throughput: sms * cores * clock ≈ 41 TOPS.
        int_throughput = sms * cores_per_sm * clock_ghz * 1e9
        t_compute = total_ops / int_throughput * 1e6  # us

        # Estimated total: roofline max of bandwidth and compute,
        # with a 1.5x fudge factor for launch/scheduling overhead.
        t_estimated = max(t_bw_kbit, t_compute) * 1.5

        print(f" {name} ({K_dim}x{N_padded}), 8 experts, M=1:")
        print(f"   kbit data: {kbit_data/1e6:.2f} MB → L2 read: {t_bw_kbit:.1f} us")
        print(f"   fp16 data: {fp16_data/1e6:.1f} MB → L2 read: {t_bw_fp16:.1f} us")
        print(f"   Compute (dequant+FMA): {total_elements/1e6:.1f}M elements × {ops_per_element} ops = {t_compute:.1f} us")
        print(f"   Estimated GEMV time: {t_estimated:.0f} us")
        print(f"   vs cuBLAS bmm ~17 us → {17/t_estimated:.1f}x")
        print()
211+
212+
213+
# Run the full benchmark suite when executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)