Skip to content

Commit 3cbaf74

Browse files
TimDettmers and claude
committed
Remove unnecessary E4M4 conversion in dequant, add dequant overhead benchmark
- Add float32 absmax support to dequantize_kbit CUDA kernel (template instantiations + C wrappers), removing the Python-side E4M4 conversion that launched ~15 PyTorch kernels per call. Dequant goes from ~800us to ~30us for large shapes (gateup/down) and ~5us for small (KV). - Add bench_dequant.sh/py: measures dequant kernel time via ncu and fp16 matmul via CUDA events, reports speed ratio (fp16 / total) per shape × k × M. Dequant scales linearly with element count and k. - Update bench_ncu.sh with model-level summary tables and grouped kernel support - Document dequant benchmark in kbit-kernel-spec.md Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 826bc00 commit 3cbaf74

File tree

10 files changed

+700
-95
lines changed

10 files changed

+700
-95
lines changed

benchmarks/bench_dequant.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
"""Dequant + cuBLAS overhead analysis.

Measures dequantize_kbit GPU kernel time per shape×k (via ncu or --use-events),
fp16 matmul time per shape×M, and computes the overhead ratio.

Usage:
    # Recommended: ncu for dequant (accurate), CUDA events for matmul
    bash benchmarks/bench_dequant.sh

    # Quick (CUDA events only, includes ~35us dispatch overhead on dequant):
    python benchmarks/bench_dequant.py --use-events

Env: M_VALS (default "4,8,16,32,64,128,256,512,1024,2048,4096")
     DEQUANT_CSV: comma-separated dequant times injected by bench_dequant.sh
                  (order: k=2 × 5 shapes, k=3 × 5, k=4 × 5, k=5 × 5)
"""
import os, sys, argparse

# Make a source checkout importable whether we run from the repo root or
# from benchmarks/ — prefer the first directory containing bitsandbytes/.
for p in [".", ".."]:
    if os.path.isdir(os.path.join(p, "bitsandbytes")):
        sys.path.insert(0, os.path.abspath(p))
        break

import torch
import bitsandbytes  # noqa: E402
from bitsandbytes.functional import create_normal_float_codebook  # noqa: E402

parser = argparse.ArgumentParser()
parser.add_argument("--use-events", action="store_true",
                    help="Use CUDA events for dequant timing (includes dispatch overhead)")
args = parser.parse_args()

# (name, K, N) weight shapes drawn from the target model's layers.
shapes = [
    ("gateup", 2048, 5120),
    ("down", 5120, 2048),
    ("Q", 2048, 4096),
    ("O", 4096, 2048),
    ("KV", 2048, 512),
]
k_bits_list = [2, 3, 4, 5]
m_vals = [int(x) for x in os.environ.get(
    "M_VALS", "4,8,16,32,64,128,256,512,1024,2048,4096").split(",")]

dev = torch.device("cuda")
start_ev = torch.cuda.Event(enable_timing=True)
end_ev = torch.cuda.Event(enable_timing=True)
WARMUP = 50
ITERS = 200

# --- Dequant times ---
# Keyed by (shape_name, k); filled either from ncu-measured values injected
# via DEQUANT_CSV, or measured here with CUDA events as a fallback.
dequant_us = {}
dequant_env = os.environ.get("DEQUANT_CSV", "")
if dequant_env:
    # Injected by bench_dequant.sh (ncu-measured)
    # Order: k=2 × 5 shapes, k=3 × 5, k=4 × 5, k=5 × 5
    vals = [float(x) for x in dequant_env.split(",")]
    expected = len(k_bits_list) * len(shapes)
    if len(vals) != expected:
        # A truncated/garbled CSV would otherwise crash with a bare
        # IndexError or silently mispair times with shapes — fail loudly.
        print(f"ERROR: DEQUANT_CSV has {len(vals)} values, expected {expected} "
              f"({len(k_bits_list)} k values × {len(shapes)} shapes)", file=sys.stderr)
        sys.exit(1)
    i = 0
    for k in k_bits_list:
        for name, _, _ in shapes:
            dequant_us[(name, k)] = vals[i]
            i += 1
elif args.use_events:
    # Fallback: CUDA events (includes ~35us dispatch overhead)
    for k in k_bits_list:
        codebook = create_normal_float_codebook(k, device=dev)
        for name, K_dim, N in shapes:
            n_elements = K_dim * N
            W = torch.randn(n_elements, device=dev, dtype=torch.float32)
            packed, absmax = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
            for _ in range(WARMUP):
                torch.ops.bitsandbytes.dequantize_kbit(
                    packed, codebook, absmax, k, n_elements, torch.float16)
            torch.cuda.synchronize()
            start_ev.record()
            for _ in range(ITERS):
                torch.ops.bitsandbytes.dequantize_kbit(
                    packed, codebook, absmax, k, n_elements, torch.float16)
            end_ev.record()
            torch.cuda.synchronize()
            # elapsed_time is in ms; convert to us per iteration.
            dequant_us[(name, k)] = start_ev.elapsed_time(end_ev) * 1000 / ITERS
else:
    print("ERROR: Run via bench_dequant.sh (ncu) or with --use-events", file=sys.stderr)
    sys.exit(1)

# --- Print dequant times ---
print("=== Dequant kernel time (us) ===")
print(f"{'shape':<8}", end="")
for k in k_bits_list:
    print(f" {'k='+str(k):>8}", end="")
print()
print("---")
for name, _, _ in shapes:
    print(f"{name:<8}", end="")
    for k in k_bits_list:
        print(f" {dequant_us[(name, k)]:>8.1f}", end="")
    print()
print()

# --- Measure fp16 matmul time per shape×M ---
matmul_us = {}
for name, K_dim, N in shapes:
    W = torch.randn(K_dim, N, dtype=torch.float16, device=dev)
    for M in m_vals:
        A = torch.randn(M, K_dim, dtype=torch.float16, device=dev)
        out = torch.empty(M, N, dtype=torch.float16, device=dev)
        for _ in range(WARMUP):
            torch.mm(A, W, out=out)
        torch.cuda.synchronize()
        start_ev.record()
        for _ in range(ITERS):
            torch.mm(A, W, out=out)
        end_ev.record()
        torch.cuda.synchronize()
        matmul_us[(name, M)] = start_ev.elapsed_time(end_ev) * 1000 / ITERS

# --- Print combined table per k ---
# "speed" is the fraction of the dequant+matmul total spent in the matmul,
# i.e. how close the quantized path gets to pure fp16 throughput.
for k in k_bits_list:
    print(f"=== k={k}: dequant + fp16 matmul overhead ===")
    print(f"{'shape':<8} {'M':>6} {'fp16 (us)':>10} {'dequant (us)':>13} {'total (us)':>11} {'speed':>7}")
    print("-" * 60)
    for name, K_dim, N in shapes:
        d = dequant_us[(name, k)]
        for M in m_vals:
            mm = matmul_us[(name, M)]
            total = d + mm
            speed = mm / total
            print(f"{name:<8} {M:>6} {mm:>10.1f} {d:>13.1f} {total:>11.1f} {speed:>7.2f}")
        print()
    print()

benchmarks/bench_dequant.sh

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
#!/bin/bash
# Dequant + cuBLAS overhead analysis.
# Uses ncu for accurate dequant kernel timing, CUDA events for matmul.
#
# Usage:
#   bash benchmarks/bench_dequant.sh
#   M_VALS=16,32,64,128,256 bash benchmarks/bench_dequant.sh
#
# pipefail: without it a failed ncu inside the pipeline below would yield an
# empty DEQUANT_CSV and a confusing downstream error instead of failing here.
set -e
set -o pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

# Phase 1: measure dequant kernel times via ncu (all shapes × all k)
echo "Measuring dequant kernel times via ncu..."
DEQUANT_CSV=$(ncu --kernel-name "kDequantizeBlockwise_kbit_vec" \
    --metrics gpu__time_duration.avg \
    python3 -c "
import sys, torch; sys.path.insert(0, '.')
import bitsandbytes
from bitsandbytes.functional import create_normal_float_codebook
shapes = [('gateup',2048,5120),('down',5120,2048),('Q',2048,4096),('O',4096,2048),('KV',2048,512)]
dev = torch.device('cuda')
for k in [2,3,4,5]:
    codebook = create_normal_float_codebook(k, device=dev)
    for name, K, N in shapes:
        n = K * N
        W = torch.randn(n, device=dev)
        packed, absmax = torch.ops.bitsandbytes.quantize_kbit(W, codebook, k)
        torch.cuda.synchronize()
        for _ in range(3):
            torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, torch.float16)
        torch.cuda.synchronize()
        torch.ops.bitsandbytes.dequantize_kbit(packed, codebook, absmax, k, n, torch.float16)
        torch.cuda.synchronize()
" 2>&1 | grep "gpu__time_duration" | awk '{print $NF}' | \
python3 -c "
import sys
vals = [float(l.strip()) for l in sys.stdin]
# 4 launches per (k, shape): 3 warmup + 1 profiled, take last
result = []
for i in range(0, len(vals), 4):
    result.append(vals[i+3])
# Output: k=2 × 5 shapes, k=3 × 5, k=4 × 5, k=5 × 5
print(','.join(f'{v:.2f}' for v in result))
")

# Sanity check: refuse to hand an empty measurement set to phase 2.
if [ -z "$DEQUANT_CSV" ]; then
    echo "ERROR: ncu produced no dequant timings (is ncu installed and is a GPU available?)" >&2
    exit 1
fi

echo "Dequant kernel times (ncu): $DEQUANT_CSV"
echo ""

# Phase 2: run the Python script with injected dequant times
DEQUANT_CSV="$DEQUANT_CSV" python3 "$SCRIPT_DIR/bench_dequant.py"

benchmarks/bench_fp16.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,71 @@
"""cuBLAS fp16 baseline — CUDA event timing, pre-allocated I/O.

Benchmarks dense matmul (torch.mm) and batched MoE matmul (torch.bmm).

Env: M_VALS (default "1,2,3,4,8"), NUM_EXPERTS (default "8")
"""
import os, torch

dense_shapes = [
    ("gateup", 2048, 5120),
    ("down", 5120, 2048),
    ("Q", 2048, 4096),
    ("O", 4096, 2048),
    ("KV", 2048, 512),
]
moe_shapes = [
    ("moe_gu", 2048, 512),
    ("moe_dn", 512, 2048),
]

m_vals = [int(x) for x in os.environ.get("M_VALS", "1,2,3,4,8").split(",")]
NUM_EXPERTS = int(os.environ.get("NUM_EXPERTS", "8"))
dev = torch.device("cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

WARMUP = 50
ITERS = 200


def _time_us(op):
    """Average GPU time of op() in microseconds.

    WARMUP untimed calls, then ITERS calls bracketed by CUDA events.
    Shared by the dense and MoE loops so the harness exists in one place.
    """
    for _ in range(WARMUP):
        op()
    torch.cuda.synchronize()
    start.record()
    for _ in range(ITERS):
        op()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time is in ms; convert to us per iteration.
    return start.elapsed_time(end) * 1000 / ITERS


# --- Dense layers (torch.mm) ---
print(f"{'shape':<8} {'M':>2} {'avg_us':>10}")
print("---")

for name, K, N in dense_shapes:
    W = torch.randn(K, N, dtype=torch.float16, device=dev)
    for M in m_vals:
        A = torch.randn(M, K, dtype=torch.float16, device=dev)
        out = torch.empty(M, N, dtype=torch.float16, device=dev)
        us = _time_us(lambda: torch.mm(A, W, out=out))
        print(f"{name:<8} {M:>2} {us:>10.2f}")

# --- MoE layers (torch.bmm) ---
print()
print(f"{'shape':<8} {'M':>2} {'nexp':>4} {'avg_us':>10}")
print("---")

for name, K, N in moe_shapes:
    # Weight: [num_experts, K, N] — each expert has its own weight matrix
    W_batch = torch.randn(NUM_EXPERTS, K, N, dtype=torch.float16, device=dev)
    for M in m_vals:
        # A: [num_experts, M, K] — M tokens per expert
        A_batch = torch.randn(NUM_EXPERTS, M, K, dtype=torch.float16, device=dev)
        out = torch.empty(NUM_EXPERTS, M, N, dtype=torch.float16, device=dev)
        us = _time_us(lambda: torch.bmm(A_batch, W_batch, out=out))
        print(f"{name:<8} {M:>2} {NUM_EXPERTS:>4} {us:>10.2f}")

benchmarks/bench_ncu.sh

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,51 @@
11
#!/bin/bash
2-
# Full kernel benchmark: MMA + scalar (ncu) + cuBLAS fp16 (CUDA events).
2+
# Full kernel benchmark: MMA + scalar + grouped (ncu) + cuBLAS fp16 (CUDA events).
3+
# Then computes end-to-end model summary for Qwen3-Coder-Next 70B.
34
#
45
# Usage:
5-
# bash benchmarks/bench_ncu.sh # default M=1,2,3,4,8
6-
# M_VALS=3,4 bash benchmarks/bench_ncu.sh # custom M values
6+
# bash benchmarks/bench_ncu.sh # default M=1..8
7+
# M_VALS=1,4 bash benchmarks/bench_ncu.sh # custom M values
78
#
8-
# Output: three tables (MMA, scalar, cuBLAS fp16) with avg kernel time
9-
# in microseconds for each shape × k × M combination.
9+
# Output: raw kernel tables, then one summary table per M value showing
10+
# all kernels side by side for every (shape, k) combination.
1011
#
11-
# Runtime: ~30-60 seconds depending on M_VALS count.
12+
# Runtime: ~2-4 minutes for M=1..8.
1213
set -e
1314

1415
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
15-
export M_VALS="${M_VALS:-1,2,3,4,8}"
16+
RESULTS_DIR="$SCRIPT_DIR/.bench_results"
17+
mkdir -p "$RESULTS_DIR"
18+
19+
export M_VALS="${M_VALS:-1,2,3,4,5,6,7,8}"
20+
export NUM_EXPERTS="${NUM_EXPERTS:-8}"
1621
WARMUP=5
1722
PROFILED=5
1823

24+
# Compute M subsets: scalar/grouped only support M<=4
25+
SCALAR_M=$(python3 -c "print(','.join(str(m) for m in [int(x) for x in '$M_VALS'.split(',')] if m <= 4))")
26+
ALL_M="$M_VALS"
27+
1928
echo "START: $(date)"
20-
echo "M values: $M_VALS"
21-
22-
for KERNEL in mma scalar; do
23-
if [ "$KERNEL" = "mma" ]; then
24-
KNAME="kbit_gemm_prod"
25-
echo ""
26-
echo "=== MMA kernel ==="
27-
else
28-
KNAME="kbit_scalar_gemv"
29-
echo ""
30-
echo "=== Scalar GEMV ==="
31-
fi
32-
printf "%-8s %2s %2s %10s\n" "shape" "k" "M" "avg_us"
33-
echo "---"
34-
35-
KERNEL=$KERNEL M_VALS=$M_VALS ncu --kernel-name "$KNAME" --metrics gpu__time_duration.avg \
29+
echo "M values: $M_VALS (scalar/grouped: $SCALAR_M)"
30+
echo "MoE experts: $NUM_EXPERTS"
31+
32+
# Helper: run ncu and parse output for a kernel
33+
run_ncu_bench() {
34+
local KTYPE="$1" # mma, scalar, grouped
35+
local KNAME="$2" # ncu kernel name filter
36+
local SHAPES="$3" # Python list literal for shape names
37+
local MVALS="$4" # M values to use
38+
39+
KERNEL=$KTYPE M_VALS=$MVALS NUM_EXPERTS=$NUM_EXPERTS \
40+
ncu --kernel-name "$KNAME" --metrics gpu__time_duration.avg \
3641
python "$SCRIPT_DIR/ncu_driver.py" 2>/dev/null | \
3742
grep "gpu__time_duration.avg" | awk '{print $NF}' | \
3843
python3 -c "
39-
import os, sys
44+
import sys
4045
vals = [float(l.strip()) for l in sys.stdin]
41-
shapes = ['gateup','down','Q','O','KV']
46+
shapes = $SHAPES
4247
kbits = [2,3,4,5]
43-
mvals = [int(x) for x in os.environ['M_VALS'].split(',')]
48+
mvals = [int(x) for x in '$MVALS'.split(',')]
4449
W, P = $WARMUP, $PROFILED
4550
i = 0
4651
for s in shapes:
@@ -51,12 +56,47 @@ for s in shapes:
5156
print(f'{s:<8} {k:>2} {m:>2} {avg:>10.2f}')
5257
i += W + P
5358
"
54-
done
59+
}
60+
61+
# ---- MMA kernel (all M values) ----
62+
echo ""
63+
echo "=== MMA kernel ==="
64+
printf "%-8s %2s %2s %10s\n" "shape" "k" "M" "avg_us"
65+
echo "---"
66+
run_ncu_bench mma "kbit_gemm_prod" "['gateup','down','Q','O','KV']" "$ALL_M" | tee "$RESULTS_DIR/mma.txt"
67+
68+
# ---- Scalar GEMV (M<=4 only) ----
69+
echo ""
70+
echo "=== Scalar GEMV (M<=4) ==="
71+
printf "%-8s %2s %2s %10s\n" "shape" "k" "M" "avg_us"
72+
echo "---"
73+
if [ -n "$SCALAR_M" ]; then
74+
run_ncu_bench scalar "kbit_scalar_gemv" "['gateup','down','Q','O','KV']" "$SCALAR_M" | tee "$RESULTS_DIR/scalar.txt"
75+
else
76+
echo "(no M<=4 values requested)" | tee "$RESULTS_DIR/scalar.txt"
77+
fi
78+
79+
# ---- Grouped expert kernel (M<=4 only) ----
80+
echo ""
81+
echo "=== Grouped scalar GEMV (${NUM_EXPERTS} experts, M<=4) ==="
82+
printf "%-8s %2s %2s %10s\n" "shape" "k" "M" "avg_us"
83+
echo "---"
84+
if [ -n "$SCALAR_M" ]; then
85+
run_ncu_bench grouped "kbit_grouped_scalar_gemv" "['moe_gu','moe_dn']" "$SCALAR_M" | tee "$RESULTS_DIR/grouped.txt"
86+
else
87+
echo "(no M<=4 values requested)" | tee "$RESULTS_DIR/grouped.txt"
88+
fi
89+
90+
# ---- cuBLAS fp16 baselines (CUDA events, all M values) ----
91+
echo ""
92+
echo "=== cuBLAS fp16 (dense mm + MoE bmm) ==="
93+
M_VALS=$ALL_M NUM_EXPERTS=$NUM_EXPERTS python "$SCRIPT_DIR/bench_fp16.py" 2>/dev/null | \
94+
tee "$RESULTS_DIR/cublas.txt"
5595

56-
# cuBLAS fp16 (CUDA events — ncu can't reliably filter cuBLAS kernels)
96+
# ---- Model-level summary ----
5797
echo ""
58-
echo "=== cuBLAS fp16 ==="
59-
M_VALS=$M_VALS python "$SCRIPT_DIR/bench_fp16.py" 2>/dev/null
98+
echo "=== Qwen3-Coder-Next 70B: weight matmul summary ==="
99+
python3 "$SCRIPT_DIR/model_summary.py" "$RESULTS_DIR"
60100

61101
echo ""
62102
echo "END: $(date)"

0 commit comments

Comments
 (0)