Skip to content

Commit b02ff66

Browse files
TimDettmers and claude committed
V8 grouped scalar GEMV + inline work distribution for grouped MMA
Grouped scalar GEMV: ported V8 optimizations (64 threads, 2 warps, vectorized int4 A loads, __launch_bounds__, M_VAL dispatch 1-4) and switched from tiled/E4M4 to flat layout with float32 absmax. Grouped MMA: replaced cudaMemcpy/cudaMalloc/cudaFree work_offsets computation with inline linear scan over expert_offsets. Caller now passes max_M directly, eliminating device-to-host sync per call. Benchmark suite: added grouped_mma kernel type to ncu_driver and bench_ncu.sh. model_summary.py now shows 5 kernel columns (MMA, Scalar, Grouped, Grp MMA, fp16) with per-k TOTAL rows. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3cbaf74 commit b02ff66

File tree

9 files changed

+299
-183
lines changed

9 files changed

+299
-183
lines changed

benchmarks/bench_ncu.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ else
8787
echo "(no M<=4 values requested)" | tee "$RESULTS_DIR/grouped.txt"
8888
fi
8989

90+
# ---- Grouped MMA kernel (all M values) ----
91+
echo ""
92+
echo "=== Grouped MMA (${NUM_EXPERTS} experts, all M) ==="
93+
printf "%-8s %2s %2s %10s\n" "shape" "k" "M" "avg_us"
94+
echo "---"
95+
run_ncu_bench grouped_mma "kbit_grouped_gemm_prod" "['moe_gu','moe_dn']" "$ALL_M" | tee "$RESULTS_DIR/grouped_mma.txt"
96+
9097
# ---- cuBLAS fp16 baselines (CUDA events, all M values) ----
9198
echo ""
9299
echo "=== cuBLAS fp16 (dense mm + MoE bmm) ==="

benchmarks/model_summary.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
side by side, the best kernel, and speedup vs cuBLAS fp16.
66
77
Dense shapes have MMA, Scalar, fp16 columns.
8-
MoE shapes have Grouped, fp16 (bmm) columns.
8+
MoE shapes have Grouped (scalar), Grp MMA, fp16 (bmm) columns.
99
"""
1010
import os, sys
1111

@@ -64,9 +64,10 @@ def main():
6464
mma = parse_results(os.path.join(results_dir, "mma.txt"))
6565
scalar = parse_results(os.path.join(results_dir, "scalar.txt"))
6666
grouped = parse_results(os.path.join(results_dir, "grouped.txt"))
67+
grouped_mma = parse_results(os.path.join(results_dir, "grouped_mma.txt"))
6768
cublas_dense, cublas_moe = parse_cublas(os.path.join(results_dir, "cublas.txt"))
6869

69-
if not mma and not scalar and not grouped:
70+
if not mma and not scalar and not grouped and not grouped_mma:
7071
print("No benchmark results found. Run bench_ncu.sh first.")
7172
return
7273

@@ -78,25 +79,22 @@ def main():
7879

7980
# Collect all M values
8081
all_M = set()
81-
for key in list(mma.keys()) + list(scalar.keys()) + list(grouped.keys()):
82-
all_M.add(key[2])
82+
for d in [mma, scalar, grouped, grouped_mma]:
83+
for key in d:
84+
all_M.add(key[2])
8385
all_M = sorted(all_M)
8486

85-
# Column widths
87+
# Column widths — 6 kernel columns
8688
SEP = "+"
87-
HDR = (f"{SEP}--------+-----+-------+--------+---------+-------+--------+---------{SEP}")
88-
TOP = (f"{SEP}========+=====+=======+========+=========+=======+========+========={SEP}")
89+
HDR = f"{SEP}--------+-----+-------+--------+---------+---------+-------+--------+---------{SEP}"
90+
TOP = f"{SEP}========+=====+=======+========+=========+=========+=======+========+========={SEP}"
8991

9092
for M in all_M:
9193
print(f"\n M={M}:")
9294
print(f" {TOP}")
93-
print(f" | {'shape':<6} | {'k':>3} | {'MMA':>5} | {'Scalar':>6} | {'Grouped':>7} | {'fp16':>5} | {'Best':>6} | {'vs fp16':>7} |")
95+
print(f" | {'shape':<6} | {'k':>3} | {'MMA':>5} | {'Scalar':>6} | {'Grouped':>7} | {'Grp MMA':>7} | {'fp16':>5} | {'Best':>6} | {'vs fp16':>7} |")
9496
print(f" {HDR}")
9597

96-
total_best = 0.0
97-
total_fp16 = 0.0
98-
all_complete = True
99-
10098
for shape in all_shapes:
10199
is_moe = shape in moe_shapes
102100

@@ -105,6 +103,7 @@ def main():
105103
m_us = mma.get((shape, k, M)) if not is_moe else None
106104
s_us = scalar.get((shape, k, M)) if not is_moe else None
107105
g_us = grouped.get((shape, k, M)) if is_moe else None
106+
gm_us = grouped_mma.get((shape, k, M)) if is_moe else None
108107
fp16 = cublas_moe.get((shape, M)) if is_moe else cublas_dense.get((shape, M))
109108

110109
# Find best kbit kernel
@@ -115,40 +114,32 @@ def main():
115114
candidates["Scalar"] = s_us
116115
if g_us is not None:
117116
candidates["Grouped"] = g_us
117+
if gm_us is not None:
118+
candidates["Grp MMA"] = gm_us
118119

119120
if candidates:
120121
best_name = min(candidates, key=candidates.get)
121122
best_us = candidates[best_name]
122123
else:
123124
best_name, best_us = None, None
124125

125-
# Compare against fp16.
126-
# "Best" = fastest kbit kernel. "vs fp16" = fp16 / Best.
127-
# >1.00x means kbit wins, <1.00x means fp16 wins (slowdown).
128-
# When no kbit kernel exists, Best falls back to fp16 and shows "-".
129126
if best_us is not None and fp16 is not None:
130127
speedup = f"{fp16 / best_us:5.2f}x"
131-
total_best += best_us
132-
total_fp16 += fp16
133128
elif fp16 is not None and best_us is None:
134-
# No kbit kernel for this config — fp16 only
135129
best_name = "-"
136130
best_us = fp16
137131
speedup = " -"
138-
total_best += fp16
139-
total_fp16 += fp16
140132
else:
141133
speedup = " N/A"
142-
all_complete = False
143134

144135
best_str = best_name if best_name else "N/A"
145136

146-
print(f" | {shape:<6} | {k:>3} | {fmt(m_us)} | {fmt(s_us):>6} | {fmt(g_us):>7} | {fmt(fp16)} | {best_str:>6} | {speedup:>7} |")
137+
print(f" | {shape:<6} | {k:>3} | {fmt(m_us)} | {fmt(s_us):>6} | {fmt(g_us):>7} | {fmt(gm_us):>7} | {fmt(fp16)} | {best_str:>7} | {speedup:>7} |")
147138

148139
print(f" {HDR}")
149140

150141
# Per-k total rows: sum best_us and fp16 across all shapes for each k
151-
print(f" | {'TOTAL':<6} | | | | | | | |")
142+
print(f" | {'TOTAL':<6} | | | | | | | | |")
152143
for k in k_bits:
153144
k_best = 0.0
154145
k_fp16 = 0.0
@@ -158,6 +149,7 @@ def main():
158149
m_us = mma.get((shape, k, M)) if not is_moe else None
159150
s_us = scalar.get((shape, k, M)) if not is_moe else None
160151
g_us = grouped.get((shape, k, M)) if is_moe else None
152+
gm_us = grouped_mma.get((shape, k, M)) if is_moe else None
161153
fp16 = cublas_moe.get((shape, M)) if is_moe else cublas_dense.get((shape, M))
162154

163155
candidates = {}
@@ -167,6 +159,8 @@ def main():
167159
candidates["Scalar"] = s_us
168160
if g_us is not None:
169161
candidates["Grouped"] = g_us
162+
if gm_us is not None:
163+
candidates["Grp MMA"] = gm_us
170164

171165
if candidates:
172166
best_us = min(candidates.values())
@@ -182,9 +176,9 @@ def main():
182176

183177
if k_complete and k_best > 0 and k_fp16 > 0:
184178
overall = k_fp16 / k_best
185-
print(f" | k={k:<3} | {k:>3} | | | | | {k_best:6.1f} | {overall:5.2f}x |")
179+
print(f" | k={k:<3} | {k:>3} | | | | | | {k_best:6.1f} | {overall:5.2f}x |")
186180
else:
187-
print(f" | k={k:<3} | {k:>3} | | | | | N/A | N/A |")
181+
print(f" | k={k:<3} | {k:>3} | | | | | | N/A | N/A |")
188182
print(f" {TOP}")
189183

190184

benchmarks/ncu_driver.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""ncu kernel driver — runs all shape x k x M configs in a single process.
22
33
Used by bench_ncu.sh. Env vars:
4-
KERNEL: "mma", "scalar", or "grouped"
4+
KERNEL: "mma", "scalar", "grouped", or "grouped_mma"
55
M_VALS: comma-separated M values (default "1,2,3,4,5,6,7,8")
6-
NUM_EXPERTS: number of active experts for grouped kernel (default 8)
6+
NUM_EXPERTS: number of active experts for grouped/grouped_mma kernel (default 8)
77
88
Each config runs WARMUP + PROFILED kernel launches. ncu captures all
99
matching launches; the sweep script skips warmup and averages profiled.
@@ -27,7 +27,7 @@
2727
m_vals = [int(x) for x in os.environ.get("M_VALS", "1,2,3,4,5,6,7,8").split(",")]
2828
NUM_EXPERTS = int(os.environ.get("NUM_EXPERTS", "8"))
2929

30-
# Scalar and grouped kernels only support M<=4
30+
# Scalar and grouped scalar kernels only support M<=4
3131
if KERNEL in ("scalar", "grouped"):
3232
m_vals = [m for m in m_vals if m <= 4]
3333

@@ -94,7 +94,52 @@
9494
torch.cuda.synchronize()
9595

9696
elif KERNEL == "grouped":
97-
# Pre-quantize MoE expert weights (NUM_EXPERTS copies per shape)
97+
# Pre-quantize MoE expert weights (NUM_EXPERTS copies, flat layout)
98+
moe_data = {}
99+
for name, K_dim, N in moe_shapes:
100+
for k in k_bits_list:
101+
codebook = create_normal_float_codebook(k, device=dev)
102+
packed_list = []
103+
absmax_list = []
104+
num_k_blocks = K_dim // 32
105+
expected_packed = N * num_k_blocks * k
106+
expected_absmax = N * num_k_blocks
107+
for _ in range(NUM_EXPERTS):
108+
W = torch.randn(K_dim * N, device=dev, dtype=torch.float32)
109+
pf, af = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
110+
packed_list.append(pf[:expected_packed])
111+
absmax_list.append(af.cuda()[:expected_absmax])
112+
B_packed_all = torch.cat(packed_list, dim=0)
113+
B_absmax_all = torch.cat(absmax_list, dim=0)
114+
moe_data[(name, k)] = (K_dim, N, B_packed_all, B_absmax_all, codebook)
115+
116+
configs = []
117+
for name, K_dim, N in moe_shapes:
118+
for k in k_bits_list:
119+
for M in m_vals:
120+
configs.append((name, k, M))
121+
122+
for name, k, M in configs:
123+
K_dim, N, B_packed_all, B_absmax_all, codebook = moe_data[(name, k)]
124+
# M tokens per expert (all experts get same M for benchmarking)
125+
total_tokens = M * NUM_EXPERTS
126+
A_concat = torch.randn(total_tokens, K_dim, dtype=torch.float16, device=dev)
127+
offsets = list(range(0, total_tokens + 1, M))
128+
expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=dev)
129+
130+
fn = lambda: torch.ops.bitsandbytes.kbit_grouped_scalar_gemv(
131+
A_concat, B_packed_all, B_absmax_all, codebook,
132+
expert_offsets, K_dim, N, k, NUM_EXPERTS, M)
133+
134+
for _ in range(WARMUP):
135+
fn()
136+
torch.cuda.synchronize()
137+
for _ in range(PROFILED):
138+
fn()
139+
torch.cuda.synchronize()
140+
141+
elif KERNEL == "grouped_mma":
142+
# Pre-quantize MoE expert weights (NUM_EXPERTS copies, tiled layout for MMA)
98143
moe_data = {}
99144
for name, K_dim, N in moe_shapes:
100145
for k in k_bits_list:
@@ -119,15 +164,14 @@
119164

120165
for name, k, M in configs:
121166
K_dim, N, B_packed_all, B_absmax_all, codebook = moe_data[(name, k)]
122-
# M tokens per expert (all experts get same M for benchmarking)
123167
total_tokens = M * NUM_EXPERTS
124168
A_concat = torch.randn(total_tokens, K_dim, dtype=torch.float16, device=dev)
125169
offsets = list(range(0, total_tokens + 1, M))
126170
expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=dev)
127171

128-
fn = lambda: torch.ops.bitsandbytes.kbit_grouped_scalar_gemv(
172+
fn = lambda: torch.ops.bitsandbytes.kbit_grouped_gemm(
129173
A_concat, B_packed_all, B_absmax_all, codebook,
130-
expert_offsets, K_dim, N, k, NUM_EXPERTS)
174+
expert_offsets, K_dim, N, k, NUM_EXPERTS, M)
131175

132176
for _ in range(WARMUP):
133177
fn()

bitsandbytes/_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def _(
606606
torch.library.define(
607607
"bitsandbytes::kbit_grouped_gemm",
608608
"(Tensor A_concat, Tensor B_packed_all, Tensor B_absmax_all, Tensor codebook, "
609-
"Tensor expert_offsets, int K_dim, int N, int k, int num_experts) -> Tensor",
609+
"Tensor expert_offsets, int K_dim, int N, int k, int num_experts, int max_M) -> Tensor",
610610
)
611611

612612

@@ -621,6 +621,7 @@ def _(
621621
N: int,
622622
k: int,
623623
num_experts: int,
624+
max_M: int,
624625
) -> torch.Tensor:
625626
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
626627
torch._check(A_concat.dim() == 2 and A_concat.shape[1] == K_dim, lambda: "A_concat must be [total_M, K_dim]")
@@ -682,7 +683,7 @@ def _(
682683
torch.library.define(
683684
"bitsandbytes::kbit_grouped_scalar_gemv",
684685
"(Tensor A_concat, Tensor B_packed_all, Tensor B_absmax_all, Tensor codebook, "
685-
"Tensor expert_offsets, int K_dim, int N, int k, int num_experts) -> Tensor",
686+
"Tensor expert_offsets, int K_dim, int N, int k, int num_experts, int max_M) -> Tensor",
686687
)
687688

688689

@@ -697,6 +698,7 @@ def _(
697698
N: int,
698699
k: int,
699700
num_experts: int,
701+
max_M: int,
700702
) -> torch.Tensor:
701703
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
702704
torch._check(A_concat.dim() == 2 and A_concat.shape[1] == K_dim, lambda: "A_concat must be [total_M, K_dim]")

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,7 @@ def _(
10851085
N: int,
10861086
k: int,
10871087
num_experts: int,
1088+
max_M: int,
10881089
) -> torch.Tensor:
10891090
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
10901091
torch._check(
@@ -1114,6 +1115,7 @@ def _(
11141115
ct.c_int(K_dim),
11151116
ct.c_int(N),
11161117
ct.c_int(num_experts),
1118+
ct.c_int(max_M),
11171119
)
11181120

11191121
return C_concat
@@ -1193,17 +1195,17 @@ def _(
11931195
N: int,
11941196
k: int,
11951197
num_experts: int,
1198+
max_M: int,
11961199
) -> torch.Tensor:
11971200
torch._check(k >= 2 and k <= 5, lambda: f"k must be 2-5, got {k}")
11981201
torch._check(
11991202
A_concat.dtype in (torch.float16, torch.bfloat16),
12001203
lambda: f"kbit_grouped_scalar_gemv supports float16 and bfloat16, got {A_concat.dtype}",
12011204
)
12021205
torch._check(B_packed_all.dtype == torch.int32, lambda: f"B_packed must be int32, got {B_packed_all.dtype}")
1203-
torch._check(B_absmax_all.dtype == torch.uint8, lambda: f"B_absmax must be uint8 (E4M4), got {B_absmax_all.dtype}")
1206+
torch._check(B_absmax_all.dtype == torch.float32, lambda: f"B_absmax must be float32, got {B_absmax_all.dtype}")
12041207
torch._check(codebook.dtype == torch.float32, lambda: f"codebook must be float32, got {codebook.dtype}")
12051208
torch._check(expert_offsets.dtype == torch.int32, lambda: f"expert_offsets must be int32, got {expert_offsets.dtype}")
1206-
torch._check(N % 128 == 0, lambda: f"N ({N}) must be divisible by 128")
12071209

12081210
total_M = A_concat.shape[0]
12091211
C_concat = torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)
@@ -1222,6 +1224,7 @@ def _(
12221224
ct.c_int(K_dim),
12231225
ct.c_int(N),
12241226
ct.c_int(num_experts),
1227+
ct.c_int(max_M),
12251228
)
12261229

12271230
return C_concat

0 commit comments

Comments (0)