1414import torch
1515
1616sys .path .insert (0 , "." )
17- import bitsandbytes # noqa: E402
18- from bitsandbytes import _ops # noqa: E402, F401
19- from bitsandbytes . functional import encode_absmax_e4m4 # noqa: E402
20- from scipy . stats import norm # noqa: E402
17+ from scipy . stats import norm
18+
19+ from bitsandbytes import _ops # noqa: F401
20+ from bitsandbytes . functional import encode_absmax_e4m4
2121
2222
2323def create_normal_float_codebook (k : int ) -> torch .Tensor :
@@ -41,15 +41,14 @@ def bench(fn, warmup=30, iters=300):
4141
4242# ─── Dense layer benchmarks (varying M) ────────────────────────────────────
4343
44+
4445def bench_dense_crossover (K_dim , N , k , codebook , M_values ):
4546 """Benchmark fused kbit GEMM vs dequant+cuBLAS vs cuBLAS-only at varying M."""
4647 N_padded = ((N + 127 ) // 128 ) * 128
4748
4849 # Quantize weight
4950 W = torch .randn (N_padded , K_dim , dtype = torch .float16 , device = "cuda" )
50- packed_flat , absmax_flat = torch .ops .bitsandbytes .quantize_kbit (
51- W .reshape (- 1 ), codebook , k
52- )
51+ packed_flat , absmax_flat = torch .ops .bitsandbytes .quantize_kbit (W .reshape (- 1 ), codebook , k )
5352 # repack_kbit expects fp32 absmax (does its own E4M4 encoding)
5453 packed_tiled , absmax_tiled = torch .ops .bitsandbytes .repack_kbit (
5554 packed_flat , absmax_flat .cuda (), K_dim , N_padded , k
@@ -66,41 +65,63 @@ def bench_dense_crossover(K_dim, N, k, codebook, M_values):
6665 A = torch .randn (M , K_dim , dtype = torch .float16 , device = "cuda" )
6766
6867 # 1. Fused kbit GEMM
69- t_fused = bench (lambda : torch .ops .bitsandbytes .kbit_gemm (
70- A , packed_tiled , absmax_tiled , codebook , K_dim , N_padded , k ,
71- ))
68+ t_fused = bench (
69+ lambda : torch .ops .bitsandbytes .kbit_gemm (
70+ A ,
71+ packed_tiled ,
72+ absmax_tiled ,
73+ codebook ,
74+ K_dim ,
75+ N_padded ,
76+ k ,
77+ )
78+ )
7279
7380 # 2. cuBLAS fp16 (baseline — assumes weights already in fp16)
7481 t_cublas = bench (lambda : torch .mm (A , W_fp16 ))
7582
7683 # 3. Dequant + cuBLAS (absmax already E4M4, no re-encoding)
7784 def dequant_then_mm ():
7885 deq = torch .ops .bitsandbytes .dequantize_kbit (
79- packed_flat , codebook , absmax_e4m4 ,
80- k , n_elements , torch .float16 ,
86+ packed_flat ,
87+ codebook ,
88+ absmax_e4m4 ,
89+ k ,
90+ n_elements ,
91+ torch .float16 ,
8192 )
8293 return torch .mm (A , deq .view (N_padded , K_dim ).T )
94+
8395 t_dq_mm = bench (dequant_then_mm )
8496
8597 # 4. Just the dequant (to see its cost)
86- t_dq_only = bench (lambda : torch .ops .bitsandbytes .dequantize_kbit (
87- packed_flat , codebook , absmax_e4m4 ,
88- k , n_elements , torch .float16 ,
89- ))
90-
91- results .append ({
92- "M" : M ,
93- "fused_us" : t_fused * 1e6 ,
94- "cublas_us" : t_cublas * 1e6 ,
95- "dq_mm_us" : t_dq_mm * 1e6 ,
96- "dq_only_us" : t_dq_only * 1e6 ,
97- })
98+ t_dq_only = bench (
99+ lambda : torch .ops .bitsandbytes .dequantize_kbit (
100+ packed_flat ,
101+ codebook ,
102+ absmax_e4m4 ,
103+ k ,
104+ n_elements ,
105+ torch .float16 ,
106+ )
107+ )
108+
109+ results .append (
110+ {
111+ "M" : M ,
112+ "fused_us" : t_fused * 1e6 ,
113+ "cublas_us" : t_cublas * 1e6 ,
114+ "dq_mm_us" : t_dq_mm * 1e6 ,
115+ "dq_only_us" : t_dq_only * 1e6 ,
116+ }
117+ )
98118
99119 return results
100120
101121
102122# ─── MoE layer benchmarks (varying batch → varying experts) ────────────────
103123
124+
104125def expected_unique_experts (batch_size , total_experts , top_k ):
105126 p_miss = (1 - top_k / total_experts ) ** batch_size
106127 return total_experts * (1 - p_miss )
@@ -124,19 +145,27 @@ def bench_moe_layer(K_dim, N, k, codebook, num_experts, M_per_expert):
124145 B_absmax_all = torch .cat (absmax_list )
125146
126147 # Build activations
127- A_list = [torch .randn (M_per_expert , K_dim , dtype = torch .float16 , device = "cuda" )
128- for _ in range (num_experts )]
148+ A_list = [torch .randn (M_per_expert , K_dim , dtype = torch .float16 , device = "cuda" ) for _ in range (num_experts )]
129149 offsets = [0 ]
130150 for i in range (num_experts ):
131151 offsets .append (offsets [- 1 ] + M_per_expert )
132152 A_concat = torch .cat (A_list )
133153 expert_offsets = torch .tensor (offsets , dtype = torch .int32 , device = "cuda" )
134154
135155 # 1. Grouped kbit GEMM
136- t_grouped = bench (lambda : torch .ops .bitsandbytes .kbit_grouped_gemm (
137- A_concat , B_packed_all , B_absmax_all , codebook ,
138- expert_offsets , K_dim , N_padded , k , num_experts ,
139- ))
156+ t_grouped = bench (
157+ lambda : torch .ops .bitsandbytes .kbit_grouped_gemm (
158+ A_concat ,
159+ B_packed_all ,
160+ B_absmax_all ,
161+ codebook ,
162+ expert_offsets ,
163+ K_dim ,
164+ N_padded ,
165+ k ,
166+ num_experts ,
167+ )
168+ )
140169
141170 # 2. cuBLAS bmm
142171 A_batched = torch .stack (A_list , dim = 0 )
@@ -149,6 +178,7 @@ def bench_moe_layer(K_dim, N, k, codebook, num_experts, M_per_expert):
149178
150179# ─── Main ──────────────────────────────────────────────────────────────────
151180
181+
152182def main ():
153183 k = 4
154184 codebook = create_normal_float_codebook (k ).cuda ()
@@ -162,7 +192,7 @@ def main():
162192 (2048 , 5120 , "dense gate/up" ),
163193 (5120 , 2048 , "dense down" ),
164194 (2048 , 4096 , "Q proj" ),
165- (2048 , 512 , "KV proj" ),
195+ (2048 , 512 , "KV proj" ),
166196 (4096 , 2048 , "O proj" ),
167197 ],
168198 "GLM4.7" : [
@@ -173,9 +203,9 @@ def main():
173203
174204 M_values = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 ]
175205
176- print (f"{ '=' * 100 } " )
206+ print (f"{ '=' * 100 } " )
177207 print (f" Part 1: Dense Layer Crossover (K={ k } , fused kbit vs dequant+cuBLAS vs cuBLAS)" )
178- print (f"{ '=' * 100 } " )
208+ print (f"{ '=' * 100 } " )
179209 print ()
180210
181211 # Store results for Part 3
@@ -188,8 +218,10 @@ def main():
188218 N_padded = ((N + 127 ) // 128 ) * 128
189219 print (f" { layer_name } ({ K_dim } x { N_padded } ):" )
190220
191- hdr = (f" { 'M' :>4} | { 'fused' :>8} { 'cuBLAS' :>8} { 'dq+mm' :>8} "
192- f"{ 'dq only' :>8} | { 'fused/cub' :>9} { 'dq+mm/cub' :>9} { 'best' :>12} " )
221+ hdr = (
222+ f" { 'M' :>4} | { 'fused' :>8} { 'cuBLAS' :>8} { 'dq+mm' :>8} "
223+ f"{ 'dq only' :>8} | { 'fused/cub' :>9} { 'dq+mm/cub' :>9} { 'best' :>12} "
224+ )
193225 print (hdr )
194226 print (" " + "-" * (len (hdr ) - 4 ))
195227
@@ -203,20 +235,22 @@ def main():
203235 best_kbit = min (r ["fused_us" ], r ["dq_mm_us" ])
204236 best_ratio = r ["cublas_us" ] / best_kbit
205237 best_label = "fused" if r ["fused_us" ] <= r ["dq_mm_us" ] else "dq+mm"
206- print (f" { r ['M' ]:4d} | { r ['fused_us' ]:7.0f} us { r ['cublas_us' ]:7.0f} us "
207- f"{ r ['dq_mm_us' ]:7.0f} us { r ['dq_only_us' ]:7.0f} us | "
208- f"{ fused_ratio :8.2f} x { dq_ratio :8.2f} x "
209- f"{ best_ratio :5.2f} x ({ best_label } )" )
238+ print (
239+ f" { r ['M' ]:4d} | { r ['fused_us' ]:7.0f} us { r ['cublas_us' ]:7.0f} us "
240+ f"{ r ['dq_mm_us' ]:7.0f} us { r ['dq_only_us' ]:7.0f} us | "
241+ f"{ fused_ratio :8.2f} x { dq_ratio :8.2f} x "
242+ f"{ best_ratio :5.2f} x ({ best_label } )"
243+ )
210244 print ()
211245 print ()
212246
213247 # ════════════════════════════════════════════════════════════════════════
214248 # Part 2: MoE layer performance at realistic batch sizes
215249 # ════════════════════════════════════════════════════════════════════════
216250
217- print (f"{ '=' * 100 } " )
218- print (f " Part 2: MoE Expert Layers (grouped kbit GEMM vs cuBLAS bmm)" )
219- print (f"{ '=' * 100 } " )
251+ print (f"{ '=' * 100 } " )
252+ print (" Part 2: MoE Expert Layers (grouped kbit GEMM vs cuBLAS bmm)" )
253+ print (f"{ '=' * 100 } " )
220254 print ()
221255
222256 moe_configs = {
@@ -263,9 +297,7 @@ def main():
263297 parts_str = []
264298
265299 for K_dim , N , name in shapes :
266- t_grp , t_bmm = bench_moe_layer (
267- K_dim , N , k , codebook , num_active_int , M_per_expert
268- )
300+ t_grp , t_bmm = bench_moe_layer (K_dim , N , k , codebook , num_active_int , M_per_expert )
269301 total_grp += t_grp
270302 total_bmm += t_bmm
271303 ratio = t_bmm / t_grp
@@ -286,9 +318,9 @@ def main():
286318 # Part 3: Full model speedup per batch size
287319 # ════════════════════════════════════════════════════════════════════════
288320
289- print (f"{ '=' * 100 } " )
290- print (f " Part 3: Full Model Speedup (all layers, per batch size)" )
291- print (f"{ '=' * 100 } " )
321+ print (f"{ '=' * 100 } " )
322+ print (" Part 3: Full Model Speedup (all layers, per batch size)" )
323+ print (f"{ '=' * 100 } " )
292324 print ()
293325 print (" Strategy: for each layer, pick the fastest kbit approach (fused or dq+cuBLAS)" )
294326 print (" and compare total time against cuBLAS fp16 (no quantization)." )
@@ -302,7 +334,7 @@ def main():
302334 "Qwen3" : {
303335 "dense" : [
304336 (2048 , 4096 , "Q proj" , 1 ),
305- (2048 , 512 , "KV proj" , 1 ),
337+ (2048 , 512 , "KV proj" , 1 ),
306338 (4096 , 2048 , "O proj" , 1 ),
307339 (2048 , 5120 , "dense gate/up" , 1 ),
308340 (5120 , 2048 , "dense down" , 1 ),
@@ -317,7 +349,7 @@ def main():
317349 (10240 , 2048 , "shared down" , 1 ),
318350 # Attention projections (estimated, hidden=2048)
319351 (2048 , 2048 , "Q proj" , 1 ),
320- (2048 , 512 , "KV proj" , 1 ),
352+ (2048 , 512 , "KV proj" , 1 ),
321353 (2048 , 2048 , "O proj" , 1 ),
322354 ],
323355 "moe_shapes" : ["routed gate/up" , "routed down" ],
@@ -330,7 +362,7 @@ def main():
330362 # (they weren't in Part 1). Do it now.
331363 glm_attn_shapes = [
332364 (2048 , 2048 , "Q proj" ),
333- (2048 , 512 , "KV proj" ),
365+ (2048 , 512 , "KV proj" ),
334366 (2048 , 2048 , "O proj" ),
335367 ]
336368 for K_dim , N , layer_name in glm_attn_shapes :
@@ -340,14 +372,16 @@ def main():
340372 dense_crossover_data [key ] = results
341373
342374 for model_name , cfg in model_layers .items ():
343- print (f"{ '─' * 80 } " )
375+ print (f"{ '─' * 80 } " )
344376 print (f" { model_name } " )
345- print (f"{ '─' * 80 } " )
377+ print (f"{ '─' * 80 } " )
346378 print ()
347379
348- hdr = (f" { 'batch' :>5} | { 'dense kbit' :>10} { 'dense cub' :>10} "
349- f"{ 'MoE kbit' :>10} { 'MoE cub' :>10} | "
350- f"{ 'total kbit' :>10} { 'total cub' :>10} { 'speedup' :>8} " )
380+ hdr = (
381+ f" { 'batch' :>5} | { 'dense kbit' :>10} { 'dense cub' :>10} "
382+ f"{ 'MoE kbit' :>10} { 'MoE cub' :>10} | "
383+ f"{ 'total kbit' :>10} { 'total cub' :>10} { 'speedup' :>8} "
384+ )
351385 print (hdr )
352386 print (" " + "-" * (len (hdr ) - 2 ))
353387
@@ -401,19 +435,21 @@ def main():
401435 total_cublas = total_dense_cublas_us + total_moe_cublas_us
402436 speedup = total_cublas / total_kbit if total_kbit > 0 else 0
403437
404- print (f" { bs :5d} | { total_dense_kbit_us :9.0f} us { total_dense_cublas_us :9.0f} us "
405- f"{ total_moe_kbit_us :9.0f} us { total_moe_cublas_us :9.0f} us | "
406- f"{ total_kbit :9.0f} us { total_cublas :9.0f} us { speedup :7.2f} x" )
438+ print (
439+ f" { bs :5d} | { total_dense_kbit_us :9.0f} us { total_dense_cublas_us :9.0f} us "
440+ f"{ total_moe_kbit_us :9.0f} us { total_moe_cublas_us :9.0f} us | "
441+ f"{ total_kbit :9.0f} us { total_cublas :9.0f} us { speedup :7.2f} x"
442+ )
407443
408444 print ()
409445
410446 # ════════════════════════════════════════════════════════════════════════
411447 # Part 4: Projected speedup with scalar kernel (theoretical)
412448 # ════════════════════════════════════════════════════════════════════════
413449
414- print (f"{ '=' * 100 } " )
415- print (f " Part 4: Projected Model Speedup WITH Scalar Kernel (theoretical)" )
416- print (f"{ '=' * 100 } " )
450+ print (f"{ '=' * 100 } " )
451+ print (" Part 4: Projected Model Speedup WITH Scalar Kernel (theoretical)" )
452+ print (f"{ '=' * 100 } " )
417453 print ()
418454 print (" Uses 1.8x overhead factor for scalar kernel estimate at M<=4." )
419455 print (" Dense layers at M<=4: scalar estimate instead of fused GEMM." )
@@ -447,14 +483,16 @@ def scalar_estimate_us(K_dim, N, k, num_experts, M_per_expert):
447483 total_exp = moe_cfg ["total_experts" ]
448484 top_k_val = moe_cfg ["top_k" ]
449485
450- print (f"{ '─' * 80 } " )
486+ print (f"{ '─' * 80 } " )
451487 print (f" { model_name } " )
452- print (f"{ '─' * 80 } " )
488+ print (f"{ '─' * 80 } " )
453489 print ()
454490
455- hdr = (f" { 'batch' :>5} | { 'dense kbit' :>10} { 'dense cub' :>10} "
456- f"{ 'MoE kbit' :>10} { 'MoE cub' :>10} | "
457- f"{ 'total kbit' :>10} { 'total cub' :>10} { 'speedup' :>8} " )
491+ hdr = (
492+ f" { 'batch' :>5} | { 'dense kbit' :>10} { 'dense cub' :>10} "
493+ f"{ 'MoE kbit' :>10} { 'MoE cub' :>10} | "
494+ f"{ 'total kbit' :>10} { 'total cub' :>10} { 'speedup' :>8} "
495+ )
458496 print (hdr )
459497 print (" " + "-" * (len (hdr ) - 2 ))
460498
@@ -465,7 +503,7 @@ def scalar_estimate_us(K_dim, N, k, num_experts, M_per_expert):
465503 total_invocations = bs * top_k_val
466504 M_per_expert = max (1 , round (total_invocations / num_active ))
467505
468- use_scalar = ( bs <= 4 )
506+ use_scalar = bs <= 4
469507
470508 # --- Dense layers ---
471509 total_dense_kbit_us = 0
@@ -497,9 +535,7 @@ def scalar_estimate_us(K_dim, N, k, num_experts, M_per_expert):
497535 N_moe = [s [1 ] for s in moe_cfg ["shapes" ] if s [2 ] == moe_name ][0 ]
498536
499537 if use_scalar :
500- t_scalar = scalar_estimate_us (
501- K_dim_moe , N_moe , k , num_active_int , M_per_expert
502- )
538+ t_scalar = scalar_estimate_us (K_dim_moe , N_moe , k , num_active_int , M_per_expert )
503539 t_kbit = t_scalar
504540 else :
505541 key = (model_name , moe_name , bs )
@@ -524,9 +560,11 @@ def scalar_estimate_us(K_dim, N, k, num_experts, M_per_expert):
524560 speedup = total_cublas / total_kbit if total_kbit > 0 else 0
525561
526562 marker = " ← scalar" if use_scalar else ""
527- print (f" { bs :5d} | { total_dense_kbit_us :9.0f} us { total_dense_cublas_us :9.0f} us "
528- f"{ total_moe_kbit_us :9.0f} us { total_moe_cublas_us :9.0f} us | "
529- f"{ total_kbit :9.0f} us { total_cublas :9.0f} us { speedup :7.2f} x{ marker } " )
563+ print (
564+ f" { bs :5d} | { total_dense_kbit_us :9.0f} us { total_dense_cublas_us :9.0f} us "
565+ f"{ total_moe_kbit_us :9.0f} us { total_moe_cublas_us :9.0f} us | "
566+ f"{ total_kbit :9.0f} us { total_cublas :9.0f} us { speedup :7.2f} x{ marker } "
567+ )
530568
531569 print ()
532570
0 commit comments