1- """Benchmark for Hadamard rotation kernel and full kbit pipeline.
1+ """Benchmark for Hadamard rotation kernel and kbit M=1 pipeline.
22
33Measures:
4- 1. Rotation standalone: all block sizes × Qwen3 K values × M=1,4
5- 2. Full pipeline (rotate + kbit_scalar_gemv_tiled): Qwen3 dense shapes at M=1, k=2,3,4
4+ 1. Rotation standalone: all block sizes x Qwen3 K values x M=1,4
5+ 2. Full pipeline (rotate + kbit_scalar_gemv_tiled): Qwen3 dense shapes at M=1, k=2..5
663. cuBLAS FP16 baseline: same shapes
774. Speedup table: pipeline vs cuBLAS
88
9- All timing via CUDA graph capture + replay for clean kernel-only measurements.
9+ Timing methodology:
10+ CUDA graph capture + batched replay. Each measurement replays the graph
11+ INNER times within a single event-timed region, then divides. This
12+ amortizes the ~14 us per-replay overhead down to negligible levels,
13+ revealing true kernel execution times. Median of OUTER measurements.
14+
15+ Usage:
16+ python benchmarks/bench_hadamard.py
17+ python benchmarks/bench_hadamard.py --inner 1000 --outer 30 # higher accuracy
1018"""
1119
20+ import argparse
1221import sys
1322
1423import torch
2231 quantize_kbit ,
2332)
2433
25- BLOCKSIZE = 32
26- WARMUP = 50
27- ITERS = 200
34+ ROTATION_BLOCK_SIZE = 64
2835
2936
3037def create_normal_float_codebook (k : int ) -> torch .Tensor :
@@ -35,14 +42,18 @@ def create_normal_float_codebook(k: int) -> torch.Tensor:
3542 return values .cuda ()
3643
3744
38- def bench_graph (fn , warmup = WARMUP , iters = ITERS ):
39- """Time a function using CUDA graph capture + replay. Returns median time in us."""
40- # Warm up on default stream
41- for _ in range (warmup ):
45+ def bench (fn , inner : int , outer : int ) -> float :
46+ """Batched CUDA graph replay timing. Returns median us per iteration.
47+
48+ Captures fn into a CUDA graph, then replays it `inner` times within a
49+ single CUDA event pair. The per-replay overhead (~14 us on RTX 4090)
50+ is amortized to ~14/inner us per iteration. Takes the median of `outer`
51+ such measurements.
52+ """
53+ for _ in range (30 ):
4254 fn ()
4355 torch .cuda .synchronize ()
4456
45- # Capture graph
4657 s = torch .cuda .Stream ()
4758 s .wait_stream (torch .cuda .current_stream ())
4859 with torch .cuda .stream (s ):
@@ -55,27 +66,34 @@ def bench_graph(fn, warmup=WARMUP, iters=ITERS):
5566 fn ()
5667 torch .cuda .synchronize ()
5768
58- # Warm up replay
59- for _ in range (10 ):
69+ for _ in range (50 ):
6070 g .replay ()
6171 torch .cuda .synchronize ()
6272
63- # Time replay
6473 times = []
65- for _ in range (iters ):
74+ for _ in range (outer ):
6675 start = torch .cuda .Event (enable_timing = True )
6776 end = torch .cuda .Event (enable_timing = True )
6877 start .record ()
69- g .replay ()
78+ for _ in range (inner ):
79+ g .replay ()
7080 end .record ()
7181 torch .cuda .synchronize ()
72- times .append (start .elapsed_time (end ) * 1000 ) # ms -> us
73-
82+ times .append (start .elapsed_time (end ) * 1000 / inner ) # ms -> us/iter
7483 times .sort ()
75- return times [len (times ) // 2 ] # median
84+ return times [len (times ) // 2 ]
85+
86+
def prepare_kbit_weights(K_dim, N, k):
    """Quantize random FP16 weights to k bits and repack into the tiled layout.

    Args:
        K_dim: reduction (input) dimension of the weight matrix.
        N: output dimension of the weight matrix.
        k: quantization bit width.

    Returns:
        Tuple ``(packed_tiled, absmax_tiled, codebook)`` suitable for the
        tiled kbit GEMV op.
    """
    # NOTE(review): needs a CUDA device and the registered bitsandbytes
    # custom op; relies on the sibling create_normal_float_codebook helper.
    weights = torch.randn(N, K_dim, dtype=torch.float16, device="cuda")
    codebook = create_normal_float_codebook(k)
    packed, absmax, _ = quantize_kbit(weights, k=k, codebook=codebook)
    tiled_packed, tiled_absmax = torch.ops.bitsandbytes.repack_kbit(
        packed, absmax, K_dim, N, k
    )
    return tiled_packed, tiled_absmax, codebook
7694
7795
78- def bench_rotation_standalone ():
96+ def bench_rotation_standalone (inner , outer ):
7997 """Benchmark rotation kernel standalone across block sizes and shapes."""
8098 print ("=" * 70 )
8199 print ("1. ROTATION STANDALONE" )
@@ -91,31 +109,20 @@ def bench_rotation_standalone():
91109 for K in k_values :
92110 for bs in block_sizes :
93111 A = torch .randn (M , K , dtype = torch .float16 , device = "cuda" )
94- t = bench_graph (lambda : hadamard_rotate (A , block_size = bs ))
95- # BW: read + write = 2 * numel * 2 bytes (fp16)
112+ t = bench (lambda : hadamard_rotate (A , block_size = bs ), inner , outer )
96113 bw = 2 * A .numel () * 2 / (t / 1e6 ) / 1e9
97- print (f"{ M :>4} { K :>6} { bs :>4} { t :>10.2f } { bw :>10.1f} " )
114+ print (f"{ M :>4} { K :>6} { bs :>4} { t :>10.3f } { bw :>10.1f} " )
98115 print ()
99116
100117
101- def prepare_kbit_weights (K_dim , N , k ):
102- """Quantize random weights and repack for tiled access."""
103- W = torch .randn (N , K_dim , dtype = torch .float16 , device = "cuda" )
104- codebook = create_normal_float_codebook (k )
105- packed , absmax , _ = quantize_kbit (W , k = k , codebook = codebook )
106- packed_tiled , absmax_tiled = torch .ops .bitsandbytes .repack_kbit (packed , absmax , K_dim , N , k )
107- return packed_tiled , absmax_tiled , codebook
108-
109-
110- def bench_pipeline ():
118+ def bench_pipeline (inner , outer ):
111119 """Benchmark full pipeline: rotate(A) + kbit_scalar_gemv."""
112120 print ("=" * 70 )
113121 print ("2. FULL PIPELINE: rotate + kbit_scalar_gemv_tiled" )
114122 print ("=" * 70 )
115123 print (f"{ 'M' :>4} { 'K' :>6} { 'N' :>6} { 'k' :>2} { 'Rotate(us)' :>11} { 'GEMV(us)' :>9} { 'Total(us)' :>10} { 'TFLOPS' :>7} " )
116124 print ("-" * 65 )
117125
118- # Qwen3-Coder-Next 70B dense shapes
119126 shapes = [
120127 (1 , 2048 , 5120 , "gate/up" ),
121128 (1 , 5120 , 2048 , "down" ),
@@ -126,42 +133,41 @@ def bench_pipeline():
126133 (4 , 5120 , 2048 , "down M=4" ),
127134 ]
128135
129- for k in [2 , 3 , 4 ]:
136+ for k in [2 , 3 , 4 , 5 ]:
130137 print (f"\n --- k={ k } ---" )
131138 for M , K_dim , N , label in shapes :
132139 packed_tiled , absmax_tiled , codebook = prepare_kbit_weights (K_dim , N , k )
133140 A = torch .randn (M , K_dim , dtype = torch .float16 , device = "cuda" )
134141
135- # Benchmark rotation alone
136142 A_copy = A .clone ()
137- t_rot = bench_graph (lambda : hadamard_rotate (A_copy , block_size = 64 ) )
143+ t_rot = bench (lambda : hadamard_rotate (A_copy , block_size = ROTATION_BLOCK_SIZE ), inner , outer )
138144
139- # Benchmark GEMV alone (tiled layout, pre-allocated output)
140145 out = torch .zeros (M , N , dtype = torch .float16 , device = "cuda" )
141- t_gemv = bench_graph (
146+ t_gemv = bench (
142147 lambda : torch .ops .bitsandbytes .kbit_scalar_gemv_tiled_ (
143148 A , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out
144- )
149+ ),
150+ inner ,
151+ outer ,
145152 )
146153
147- # Benchmark combined
148154 def pipeline ():
149- hadamard_rotate (A_copy , block_size = 64 )
155+ hadamard_rotate (A_copy , block_size = ROTATION_BLOCK_SIZE )
150156 torch .ops .bitsandbytes .kbit_scalar_gemv_tiled_ (
151157 A_copy , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out
152158 )
153159
154- t_total = bench_graph (pipeline )
160+ t_total = bench (pipeline , inner , outer )
155161
156162 flops = 2 * M * K_dim * N
157163 tflops = flops / (t_total / 1e6 ) / 1e12
158164 print (
159- f"{ M :>4} { K_dim :>6} { N :>6} { k :>2} { t_rot :>11.2f } { t_gemv :>9.2f } "
160- f"{ t_total :>10.2f } { tflops :>7.3f} { label } "
165+ f"{ M :>4} { K_dim :>6} { N :>6} { k :>2} { t_rot :>11.3f } { t_gemv :>9.3f } "
166+ f"{ t_total :>10.3f } { tflops :>7.3f} { label } "
161167 )
162168
163169
164- def bench_cublas_baseline ():
170+ def bench_cublas_baseline (inner , outer ):
165171 """Benchmark cuBLAS FP16 GEMM for the same shapes."""
166172 print ("\n " + "=" * 70 )
167173 print ("3. cuBLAS FP16 BASELINE" )
@@ -184,16 +190,16 @@ def bench_cublas_baseline():
184190 W = torch .randn (N , K_dim , dtype = torch .float16 , device = "cuda" )
185191 out = torch .empty (M , N , dtype = torch .float16 , device = "cuda" )
186192
187- t = bench_graph (lambda : torch .mm (A , W .t (), out = out ))
193+ t = bench (lambda : torch .mm (A , W .t (), out = out ), inner , outer )
188194 flops = 2 * M * K_dim * N
189195 tflops = flops / (t / 1e6 ) / 1e12
190- print (f"{ M :>4} { K_dim :>6} { N :>6} { t :>9.2f } { tflops :>7.3f} " )
196+ print (f"{ M :>4} { K_dim :>6} { N :>6} { t :>9.3f } { tflops :>7.3f} " )
191197
192198
193- def bench_speedup_table ():
199+ def bench_speedup_table (inner , outer ):
194200 """Print a speedup comparison table: pipeline vs cuBLAS."""
195201 print ("\n " + "=" * 70 )
196- print ("4. SPEEDUP TABLE: kbit pipeline vs cuBLAS FP16" )
202+ print ("4. SPEEDUP TABLE: Rot + kbit GEMV vs cuBLAS FP16" )
197203 print ("=" * 70 )
198204
199205 shapes = [
@@ -208,7 +214,7 @@ def bench_speedup_table():
208214 print (f"{ 'Shape' :>20} { 'k' :>2} { 'Pipeline(us)' :>13} { 'cuBLAS(us)' :>11} { 'Speedup' :>8} " )
209215 print ("-" * 65 )
210216
211- for k in [2 , 3 , 4 ]:
217+ for k in [2 , 3 , 4 , 5 ]:
212218 print (f"\n --- k={ k } ---" )
213219 for M , K_dim , N , label in shapes :
214220 packed_tiled , absmax_tiled , codebook = prepare_kbit_weights (K_dim , N , k )
@@ -217,40 +223,36 @@ def bench_speedup_table():
217223 out = torch .zeros (M , N , dtype = torch .float16 , device = "cuda" )
218224 A_copy = A .clone ()
219225
220- # Pipeline: rotate + GEMV
221226 def pipeline ():
222- hadamard_rotate (A_copy , block_size = 64 )
227+ hadamard_rotate (A_copy , block_size = ROTATION_BLOCK_SIZE )
223228 torch .ops .bitsandbytes .kbit_scalar_gemv_tiled_ (
224229 A_copy , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out
225230 )
226231
227- t_pipe = bench_graph (pipeline )
228-
229- # cuBLAS baseline
230- t_cublas = bench_graph (lambda : torch .mm (A , W .t (), out = out ))
232+ t_pipe = bench (pipeline , inner , outer )
233+ t_cublas = bench (lambda : torch .mm (A , W .t (), out = out ), inner , outer )
231234
232235 speedup = t_cublas / t_pipe
233236 shape_str = f"{ M } x{ K_dim } x{ N } "
234- print (f"{ shape_str :>20} { k :>2} { t_pipe :>13.2f } { t_cublas :>11.2f } { speedup :>7.2f} x { label } " )
237+ print (f"{ shape_str :>20} { k :>2} { t_pipe :>13.3f } { t_cublas :>11.3f } { speedup :>7.2f} x { label } " )
235238
236239
237- def bench_cuda_graph_capture ():
238- """Verify that all benchmarks above were graph-captured (implicit from bench_graph).
239- This just confirms the pipeline captures as a single graph explicitly."""
240- print ("\n " + "=" * 70 )
241- print ("5. CUDA GRAPH CAPTURE VERIFICATION" )
242- print ("=" * 70 )
243- print ("All benchmarks above used CUDA graph capture + replay for timing." )
244- print ("If they produced numbers, graph capture succeeded for all operations." )
def main():
    """Parse CLI options and run every benchmark section in order."""
    parser = argparse.ArgumentParser(description="Hadamard rotation + kbit M=1 pipeline benchmark")
    parser.add_argument("--inner", type=int, default=500, help="Graph replays per measurement (default: 500)")
    parser.add_argument("--outer", type=int, default=15, help="Measurements per benchmark (default: 15)")
    args = parser.parse_args()

    # Report the environment and timing methodology before any numbers.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA: {torch.version.cuda}")
    print(f"Timing: batched graph replay ({args.inner} replays/measurement, median of {args.outer})")
    print()

    # Sections run in fixed order; each takes the same (inner, outer) knobs.
    sections = (
        bench_rotation_standalone,
        bench_pipeline,
        bench_cublas_baseline,
        bench_speedup_table,
    )
    for section in sections:
        section(args.inner, args.outer)


if __name__ == "__main__":
    main()