Skip to content

Commit 8e527ff

Browse files
TimDettmers and claude committed
Add grouped expert GEMM kernel for MoE inference
Batches multiple MoE expert GEMM invocations into a single kernel launch, solving the low SM utilization problem for individual expert shapes (3-12% → 100%). Reuses the v1 production inner loop unchanged. Kernel: persistent work loop with binary-search work distribution across (expert_id, m_tile, n_tile). C++ launcher reads expert_offsets from device, computes work_offsets internally to avoid Python-side GPU→CPU sync overhead. Benchmark results (K=4, RTX 4090, vs sequential cuBLAS): Qwen3 gate/up 8exp M=4: 2.0x Qwen3 gate/up 32exp M=1: 5.8x Qwen3 down 8exp M=1: 2.0x GLM4.7 routed 8exp M=1: 1.5x 10 new tests, 195 existing tests unaffected. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0d77a61 commit 8e527ff

File tree

6 files changed

+1015
-0
lines changed

6 files changed

+1015
-0
lines changed

benchmarks/bench_grouped_gemm.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
"""Benchmark for kbit grouped expert GEMM kernel.
2+
3+
Compares:
4+
1. Grouped GEMM (one kernel launch for all experts)
5+
2. Individual kbit_gemm_prod calls (one per expert, sequential)
6+
3. cuBLAS fp16 GEMM (one per expert, sequential)
7+
8+
Simulates MoE inference with varying batch sizes and expert counts.
9+
"""
10+
11+
import argparse
12+
import sys
13+
import time
14+
15+
import torch
16+
17+
sys.path.insert(0, ".")
18+
import bitsandbytes # noqa: E402
19+
from bitsandbytes import _ops # noqa: E402, F401
20+
from scipy.stats import norm # noqa: E402
21+
22+
BLOCKSIZE = 32
23+
24+
25+
def create_normal_float_codebook(k: int) -> torch.Tensor:
    """Build a normal-float codebook with 2**k levels, scaled to [-1, 1].

    Levels are the standard-normal inverse CDF evaluated at evenly spaced
    quantile midpoints; dividing by the largest magnitude puts the extreme
    (symmetric) entries exactly at -1 and +1.
    """
    levels = 2 ** k
    half_step = 0.5 / levels
    probs = torch.linspace(half_step, 1.0 - half_step, levels)
    codebook = torch.tensor(norm.ppf(probs.numpy()), dtype=torch.float32)
    return codebook / codebook.abs().max()
31+
32+
33+
def prepare_expert_weights(K_dim, N, k, num_experts):
    """Create and k-bit-quantize `num_experts` random [N, K_dim] fp16 weights.

    Returns the concatenated tiled-packed weights and absmax scales (inputs
    for the grouped kernel), the shared codebook, plus the per-expert fp16
    weights, packed weights, and absmax tensors (inputs for the individual
    per-expert baselines).
    """
    codebook = create_normal_float_codebook(k).cuda()

    W_list, packed_list, absmax_list = [], [], []
    for _ in range(num_experts):
        W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")

        # Flat k-bit quantization, then a repack into the tiled layout
        # consumed by the GEMM kernels.
        flat_packed, flat_absmax = torch.ops.bitsandbytes.quantize_kbit(
            W.reshape(-1), codebook, k
        )
        tiled_packed, tiled_absmax = torch.ops.bitsandbytes.repack_kbit(
            flat_packed, flat_absmax.cuda(), K_dim, N, k
        )

        W_list.append(W)
        packed_list.append(tiled_packed)
        absmax_list.append(tiled_absmax)

    # Concatenate per-expert buffers into the single tensors the grouped
    # kernel indexes by expert.
    B_packed_all = torch.cat(packed_list, dim=0)
    B_absmax_all = torch.cat(absmax_list, dim=0)
    return B_packed_all, B_absmax_all, codebook, W_list, packed_list, absmax_list
54+
55+
56+
def bench_grouped_gemm(A_concat, B_packed_all, B_absmax_all, codebook,
                       expert_offsets, K_dim, N, k, num_experts,
                       warmup=20, iters=200):
    """Time the single-launch grouped GEMM covering all experts.

    Runs `warmup` untimed invocations, synchronizes, then times `iters`
    invocations and returns the mean wall-clock seconds per invocation.
    """
    grouped = torch.ops.bitsandbytes.kbit_grouped_gemm
    op_args = (
        A_concat, B_packed_all, B_absmax_all, codebook,
        expert_offsets, K_dim, N, k, num_experts,
    )

    for _ in range(warmup):
        grouped(*op_args)
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(iters):
        grouped(*op_args)
    # Kernel launches are async: sync before reading the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    return elapsed / iters
74+
75+
76+
def bench_individual_kbit(A_list, packed_list, absmax_list, codebook,
                          K_dim, N, k, warmup=20, iters=200):
    """Time one sequential kbit_gemm_prod launch per expert (the baseline
    the grouped kernel is compared against).

    Args:
        A_list: per-expert activation tensors.
        packed_list / absmax_list: per-expert quantized weights and scales.
        codebook: shared k-bit dequantization codebook.
        K_dim, N, k: reduction dim, output dim, and bit width.
        warmup: untimed iterations run first.
        iters: timed iterations.

    Returns:
        Mean wall-clock seconds per iteration (one iteration = all experts,
        launched sequentially).
    """
    # Iterate the three parallel per-expert lists in lockstep instead of
    # indexing them by position.
    for _ in range(warmup):
        for A, packed, absmax in zip(A_list, packed_list, absmax_list):
            torch.ops.bitsandbytes.kbit_gemm_prod(
                A, packed, absmax, codebook,
                K_dim, N, k, 1,
            )
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        for A, packed, absmax in zip(A_list, packed_list, absmax_list):
            torch.ops.bitsandbytes.kbit_gemm_prod(
                A, packed, absmax, codebook,
                K_dim, N, k, 1,
            )
    # Launches are async; sync before stopping the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
95+
96+
97+
def bench_individual_cublas(A_list, W_list, warmup=20, iters=200):
    """Time one sequential cuBLAS fp16 GEMM (torch.mm) per expert.

    Args:
        A_list: per-expert activation tensors.
        W_list: per-expert fp16 weight matrices; each GEMM is A @ W.T.
        warmup: untimed iterations run first.
        iters: timed iterations.

    Returns:
        Mean wall-clock seconds per iteration (all experts, sequential).
    """
    # zip the parallel lists rather than indexing by position.
    for _ in range(warmup):
        for A, W in zip(A_list, W_list):
            torch.mm(A, W.T)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(iters):
        for A, W in zip(A_list, W_list):
            torch.mm(A, W.T)
    # Launches are async; sync before stopping the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
109+
110+
111+
def main():
    """Run the grouped-vs-individual expert GEMM benchmark over a fixed set
    of MoE shapes and print one results row per configuration."""
    parser = argparse.ArgumentParser(description="Benchmark grouped expert GEMM")
    parser.add_argument("--k", type=int, default=4, help="Bit width (2-5)")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=200)
    args = parser.parse_args()

    k = args.k

    # MoE scenarios
    configs = [
        # (K_dim, N, num_experts, M_per_expert, description)
        # Qwen3-Coder-Next gate/up expert
        (2048, 512, 8, 1, "Qwen3 gate/up 8exp M=1"),
        (2048, 512, 8, 4, "Qwen3 gate/up 8exp M=4"),
        (2048, 512, 8, 8, "Qwen3 gate/up 8exp M=8"),
        (2048, 512, 32, 1, "Qwen3 gate/up 32exp M=1"),
        (2048, 512, 64, 1, "Qwen3 gate/up 64exp M=1"),
        (2048, 512, 128, 1, "Qwen3 gate/up 128exp M=1"),
        # Qwen3-Coder-Next down expert
        (512, 2048, 8, 1, "Qwen3 down 8exp M=1"),
        (512, 2048, 8, 4, "Qwen3 down 8exp M=4"),
        (512, 2048, 64, 1, "Qwen3 down 64exp M=1"),
        # GLM-4.7-Flash routed expert
        (2048, 1536, 8, 1, "GLM4.7 routed 8exp M=1"),
        (2048, 1536, 8, 4, "GLM4.7 routed 8exp M=4"),
        (2048, 1536, 64, 1, "GLM4.7 routed 64exp M=1"),
    ]

    print(f"Grouped Expert GEMM Benchmark: K={k}")
    print(f"Warmup={args.warmup}, Iters={args.iters}")
    print()
    print(f"{'Description':<30} | {'K_dim':>5} {'N':>5} {'#exp':>4} {'M/e':>3} | "
          f"{'Grouped(us)':>11} {'Indiv(us)':>10} {'cuBLAS(us)':>10} | "
          f"{'vs Indiv':>8} {'vs cuBLAS':>9}")
    print("-" * 120)

    for K_dim, N, num_experts, M_per_expert, desc in configs:
        # The grouped kernel requires N divisible by 128 (see the CUDA
        # launcher's N % 128 check), so round N up before quantizing.
        N_padded = ((N + 127) // 128) * 128

        B_packed_all, B_absmax_all, codebook, W_list, packed_list, absmax_list = (
            prepare_expert_weights(K_dim, N_padded, k, num_experts)
        )

        # Build activations
        A_list = []
        offsets = [0]
        for i in range(num_experts):
            A_i = torch.randn(M_per_expert, K_dim, dtype=torch.float16, device="cuda")
            A_list.append(A_i)
            offsets.append(offsets[-1] + M_per_expert)

        # expert_offsets has num_experts + 1 cumulative row boundaries,
        # starting at 0; the launcher reads it from device memory.
        A_concat = torch.cat(A_list, dim=0)
        expert_offsets = torch.tensor(offsets, dtype=torch.int32, device="cuda")

        # Benchmark grouped
        t_grouped = bench_grouped_gemm(
            A_concat, B_packed_all, B_absmax_all, codebook,
            expert_offsets, K_dim, N_padded, k, num_experts,
            warmup=args.warmup, iters=args.iters,
        )

        # Benchmark individual kbit
        t_individual = bench_individual_kbit(
            A_list, packed_list, absmax_list, codebook,
            K_dim, N_padded, k,
            warmup=args.warmup, iters=args.iters,
        )

        # Benchmark individual cuBLAS
        W_fp16_list = [W.half().cuda() for W in W_list]
        t_cublas = bench_individual_cublas(
            A_list, W_fp16_list,
            warmup=args.warmup, iters=args.iters,
        )

        # Speedups > 1.0 mean the grouped launch is faster.
        speedup_vs_indiv = t_individual / t_grouped
        speedup_vs_cublas = t_cublas / t_grouped

        print(f"{desc:<30} | {K_dim:5d} {N_padded:5d} {num_experts:4d} {M_per_expert:3d} | "
              f"{t_grouped*1e6:11.1f} {t_individual*1e6:10.1f} {t_cublas*1e6:10.1f} | "
              f"{speedup_vs_indiv:7.2f}x {speedup_vs_cublas:8.2f}x")

    print()


if __name__ == "__main__":
    main()

bitsandbytes/_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,3 +599,33 @@ def _(
599599
torch._check(A.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A.dtype}")
600600
M = A.shape[0]
601601
return torch.empty(M, N, device=A.device, dtype=A.dtype)
602+
603+
604+
# K-bit grouped expert GEMM: batch multiple MoE expert GEMMs into one launch.
# A_concat stacks every expert's activation rows; expert_offsets carries the
# cumulative row boundaries that partition A_concat by expert.

torch.library.define(
    "bitsandbytes::kbit_grouped_gemm",
    "(Tensor A_concat, Tensor B_packed_all, Tensor B_absmax_all, Tensor codebook, "
    "Tensor expert_offsets, int K_dim, int N, int k, int num_experts) -> Tensor",
)
611+
612+
613+
@register_fake("bitsandbytes::kbit_grouped_gemm")
def _(
    A_concat: torch.Tensor,
    B_packed_all: torch.Tensor,
    B_absmax_all: torch.Tensor,
    codebook: torch.Tensor,
    expert_offsets: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    num_experts: int,
) -> torch.Tensor:
    """Fake (meta) implementation: validate the arguments and allocate the
    [total_M, N] output without running the kernel, so shape inference under
    fake tensors works for the grouped expert GEMM."""
    torch._check(2 <= k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(A_concat.dim() == 2 and A_concat.shape[1] == K_dim, lambda: "A_concat must be [total_M, K_dim]")
    torch._check(
        A_concat.dtype in (torch.float16, torch.bfloat16), lambda: f"A must be fp16 or bf16, got {A_concat.dtype}"
    )
    rows = A_concat.shape[0]
    return torch.empty(rows, N, device=A_concat.device, dtype=A_concat.dtype)

bitsandbytes/backends/cuda/ops.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,3 +1078,48 @@ def _(
10781078
)
10791079

10801080
return C
1081+
1082+
1083+
@register_kernel("bitsandbytes::kbit_grouped_gemm", "cuda")
def _(
    A_concat: torch.Tensor,
    B_packed_all: torch.Tensor,
    B_absmax_all: torch.Tensor,
    codebook: torch.Tensor,
    expert_offsets: torch.Tensor,
    K_dim: int,
    N: int,
    k: int,
    num_experts: int,
) -> torch.Tensor:
    """CUDA launcher for the grouped MoE expert GEMM.

    For each expert e, computes rows expert_offsets[e]:expert_offsets[e+1]
    of the output as A_concat[rows] @ B_e^T, all in one kernel launch.

    Args:
        A_concat: [total_M, K_dim] fp16/bf16 activations, experts stacked.
        B_packed_all: int32 packed k-bit weights, all experts concatenated.
        B_absmax_all: uint8 (E4M4) absmax scales, all experts concatenated.
        codebook: float32 dequantization codebook.
        expert_offsets: int32 [num_experts + 1] cumulative row offsets into
            A_concat, starting at 0, resident on the GPU.
        K_dim: reduction dimension.
        N: output columns; must be a multiple of 128 (kernel tile width).
        k: bit width, 2-5.
        num_experts: number of experts batched into the launch.

    Returns:
        [total_M, N] output in A_concat's dtype.
    """
    torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
    torch._check(
        A_concat.dtype in (torch.float16, torch.bfloat16),
        lambda: f"kbit_grouped_gemm supports float16 and bfloat16, got {A_concat.dtype}",
    )
    torch._check(B_packed_all.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed_all.dtype}")
    torch._check(B_absmax_all.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax_all.dtype}")
    torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
    torch._check(expert_offsets.dtype == torch.int32, lambda: f"expert_offsets must be int32, got {expert_offsets.dtype}")
    torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
    # The launcher walks num_experts row segments, so it needs exactly
    # num_experts + 1 boundary offsets ([0, m_0, m_0 + m_1, ...]).
    torch._check(
        expert_offsets.numel() == num_experts + 1,
        lambda: f"expert_offsets must have num_experts + 1 = {num_experts + 1} entries, got {expert_offsets.numel()}",
    )
    # The C++ launcher reads the offsets from device memory (to avoid a
    # GPU->CPU sync); a CPU tensor's pointer would be invalid there.
    torch._check(expert_offsets.is_cuda, lambda: "expert_offsets must be a CUDA tensor")

    total_M = A_concat.shape[0]
    C_concat = torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)

    # Kernel entry points are monomorphized per activation dtype and bit width.
    dtype_suffix = "fp16" if A_concat.dtype == torch.float16 else "bf16"

    with _cuda_device_of(A_concat):
        fn = getattr(lib, f"ckbit_grouped_gemm_prod_{dtype_suffix}_k{k}")
        fn(
            get_ptr(A_concat),
            get_ptr(B_packed_all),
            get_ptr(B_absmax_all),
            get_ptr(codebook),
            get_ptr(C_concat),
            get_ptr(expert_offsets),
            ct.c_int(K_dim),
            ct.c_int(N),
            ct.c_int(num_experts),
        )

    return C_concat

0 commit comments

Comments
 (0)