Skip to content

Commit b02b657

Browse files
TimDettmers and claude
committed
Remove dead warpspec/dqonce kernels, add deployment analysis docs
Remove the warp-specialized and dequant-once grouped GEMM kernels from ops.cu — both were correct but slower than the baseline on Ada due to register pressure from multiple accumulator sets. Also remove the unused get_num_sms() helper. See moe-kernel-spec.md for the full post-mortem. Add deployment-summary.md with per-kernel performance tables at M=1/4/64+, workload-weighted vLLM analysis, and memory savings. Add moe-kernel-spec.md documenting the MoE optimization attempts and hybrid dequant+cuBLAS BMM benchmarks. Update kbit-kernel-spec.md with grouped scalar GEMV section and current optimization priorities. 226/226 tests pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 23f92e5 commit b02b657

File tree

7 files changed

+1174
-87
lines changed

7 files changed

+1174
-87
lines changed

benchmarks/bench_fp16_moe_sweep.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""fp16 BMM baseline for MoE shapes across wide M range.
2+
3+
Uses CUDA events (accurate for fp16 bmm which has no Python overhead).
4+
"""
5+
import torch
6+
7+
NUM_EXPERTS = 8
8+
WARMUP = 50
9+
ITERS = 200
10+
11+
dev = torch.device("cuda")
12+
13+
shapes = [
14+
("moe_gu", 2048, 512),
15+
("moe_dn", 512, 2048),
16+
]
17+
18+
m_vals = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
19+
20+
print(f"{'shape':<8} {'M':>5} {'fp16_us':>8}")
21+
print("-" * 24)
22+
23+
for name, K_dim, N in shapes:
24+
for M in m_vals:
25+
A = torch.randn(NUM_EXPERTS, M, K_dim, dtype=torch.float16, device=dev)
26+
B = torch.randn(NUM_EXPERTS, K_dim, N, dtype=torch.float16, device=dev)
27+
28+
fn = lambda: torch.bmm(A, B)
29+
for _ in range(WARMUP):
30+
fn()
31+
torch.cuda.synchronize()
32+
33+
start = torch.cuda.Event(enable_timing=True)
34+
end = torch.cuda.Event(enable_timing=True)
35+
start.record()
36+
for _ in range(ITERS):
37+
fn()
38+
end.record()
39+
torch.cuda.synchronize()
40+
t = start.elapsed_time(end) / ITERS * 1000 # us
41+
print(f"{name:<8} {M:>5} {t:>8.1f}")
42+
print()

benchmarks/ncu_moe_sweep.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""NCU driver for MoE grouped MMA sweep across wide M range.
2+
3+
Only k=4, but all power-of-2 M values from 1 to 4096.
4+
Usage: ncu --kernel-name "kbit_grouped_gemm_prod" --metrics gpu__time_duration.avg python benchmarks/ncu_moe_sweep.py
5+
"""
6+
import os, sys, torch
7+
8+
for p in [".", ".."]:
9+
if os.path.isdir(os.path.join(p, "bitsandbytes")):
10+
sys.path.insert(0, os.path.abspath(p))
11+
break
12+
13+
import bitsandbytes
14+
from bitsandbytes.functional import create_normal_float_codebook
15+
16+
NUM_EXPERTS = 8
17+
K_BITS = 4
18+
WARMUP = 3
19+
PROFILED = 5
20+
21+
dev = torch.device("cuda")
22+
codebook = create_normal_float_codebook(K_BITS, device=dev)
23+
24+
shapes = [
25+
("moe_gu", 2048, 512),
26+
("moe_dn", 512, 2048),
27+
]
28+
29+
m_vals = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
30+
31+
# Pre-quantize
32+
moe_data = {}
33+
for name, K_dim, N in shapes:
34+
packed_list, absmax_list = [], []
35+
for _ in range(NUM_EXPERTS):
36+
W = torch.randn(K_dim * N, device=dev, dtype=torch.float32)
37+
pf, af = torch.ops.bitsandbytes.quantize_kbit(W, codebook, K_BITS)
38+
pt, at = torch.ops.bitsandbytes.repack_kbit(pf, af, K_dim, N, K_BITS)
39+
packed_list.append(pt)
40+
absmax_list.append(at)
41+
B_packed_all = torch.cat(packed_list, dim=0)
42+
B_absmax_all = torch.cat(absmax_list, dim=0)
43+
moe_data[name] = (K_dim, N, B_packed_all, B_absmax_all)
44+
45+
# Print config to stderr
46+
print(f"shapes={[s[0] for s in shapes]} k={K_BITS} M={m_vals} W={WARMUP} P={PROFILED}", file=sys.stderr)
47+
48+
for name, K_dim, N in shapes:
49+
K_dim, N, B_packed_all, B_absmax_all = moe_data[name]
50+
for M in m_vals:
51+
total_tokens = M * NUM_EXPERTS
52+
A_concat = torch.randn(total_tokens, K_dim, dtype=torch.float16, device=dev)
53+
offsets = list(range(0, total_tokens + 1, M))
54+
expert_offsets = torch.tensor(offsets, dtype=torch.int32, device=dev)
55+
56+
fn = lambda: torch.ops.bitsandbytes.kbit_grouped_gemm(
57+
A_concat, B_packed_all, B_absmax_all, codebook,
58+
expert_offsets, K_dim, N, K_BITS, NUM_EXPERTS, M)
59+
60+
for _ in range(WARMUP):
61+
fn()
62+
torch.cuda.synchronize()
63+
for _ in range(PROFILED):
64+
fn()
65+
torch.cuda.synchronize()

benchmarks/ncu_single_moe.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Single MoE kernel invocation for detailed NCU profiling."""
2+
import torch, sys
3+
sys.path.insert(0, ".")
4+
import bitsandbytes
5+
from bitsandbytes.functional import quantize_kbit, create_normal_float_codebook
6+
7+
torch.manual_seed(42)
8+
k, K_dim, N, num_experts, M = 4, 2048, 512, 8, 512
9+
codebook = create_normal_float_codebook(k, device="cuda")
10+
11+
packed_list, absmax_list = [], []
12+
for e in range(num_experts):
13+
W = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
14+
packed, absmax, _ = quantize_kbit(W, k, codebook=codebook, absmax_format="fp32")
15+
packed_list.append(packed)
16+
absmax_list.append(absmax)
17+
18+
B_packed_list, B_absmax_list = [], []
19+
for e in range(num_experts):
20+
bp, ba = torch.ops.bitsandbytes.repack_kbit(packed_list[e], absmax_list[e], K_dim, N, k)
21+
B_packed_list.append(bp)
22+
B_absmax_list.append(ba)
23+
24+
B_packed_all = torch.cat(B_packed_list)
25+
B_absmax_all = torch.cat(B_absmax_list)
26+
27+
A_list, offsets = [], [0]
28+
for e in range(num_experts):
29+
A = torch.randn(M, K_dim, dtype=torch.float16, device="cuda")
30+
A_list.append(A)
31+
offsets.append(offsets[-1] + M)
32+
A_concat = torch.cat(A_list, dim=0)
33+
eo = torch.tensor(offsets, dtype=torch.int32, device="cuda")
34+
35+
# Warmup
36+
for _ in range(3):
37+
C = torch.ops.bitsandbytes.kbit_grouped_gemm(
38+
A_concat, B_packed_all, B_absmax_all, codebook, eo, K_dim, N, k, num_experts, M)
39+
torch.cuda.synchronize()
40+
41+
# Profiled call
42+
C = torch.ops.bitsandbytes.kbit_grouped_gemm(
43+
A_concat, B_packed_all, B_absmax_all, codebook, eo, K_dim, N, k, num_experts, M)
44+
torch.cuda.synchronize()

csrc/ops.cu

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,6 +2507,11 @@ __global__ void kbit_grouped_gemm_prod(
25072507
} // end persistent work loop
25082508
}
25092509

2510+
// [REMOVED: Warp-specialized and dequant-once grouped GEMM kernels.
2511+
// Both were correct but slower than the baseline on Ada (sm_89) due to
2512+
// register pressure from multiple accumulator sets. See moe-kernel-spec.md
2513+
// for the full analysis. Code removed in dead-code cleanup.]
2514+
25102515
// Grouped GEMM launcher — supports TILE_N=64/128 and auto k_splits
25112516
template <int K, int MB, int TN, typename scalar_t>
25122517
static void kbitGroupedGemmProdLaunch(
@@ -2605,16 +2610,6 @@ void kbitGroupedGemmProd(
26052610
}
26062611
}
26072612

2608-
// Cached SM count to avoid repeated cudaGetDevice/cudaDeviceGetAttribute calls
2609-
static int cached_num_sms = 0;
2610-
static int get_num_sms() {
2611-
if (cached_num_sms == 0) {
2612-
int dev;
2613-
cudaGetDevice(&dev);
2614-
cudaDeviceGetAttribute(&cached_num_sms, cudaDevAttrMultiProcessorCount, dev);
2615-
}
2616-
return cached_num_sms;
2617-
}
26182613

26192614
// ===================================================================
26202615
// Scalar GEMV kernel: C[M,N] = A[M,K_dim] * W_kbit^T (M=1..4)

0 commit comments

Comments (0)