11"""Benchmark tiled vs flat scalar GEMV with pre-allocated output buffers.
22
33Measures kernel-only time by pre-allocating all buffers before the timing loop.
4- No allocations inside the measured region — fair comparison between flat and tiled.
4+ No allocations inside the measured region — fair comparison between flat, tiled, and tiled v2.
55
66Usage:
77 python benchmarks/bench_tiled_vs_flat.py
4040M_VALUES = [1 , 2 , 4 ]
4141
4242if args .graph :
43- print (f"{ 'shape' :<8} { 'K_dim' :>5} { 'N' :>5} { 'k' :>2} { 'M' :>2} { 'flat_us' :>8} { '±flat' :>6} { 'tiled_us' :>8} { '±tiled' :>6} { 'diff%' :>7} " )
44- print ("-" * 76 )
43+ print (
44+ f"{ 'shape' :<8} { 'K_dim' :>5} { 'N' :>5} { 'k' :>2} { 'M' :>2} "
45+ f" { 'flat_us' :>8} { '±flat' :>6} "
46+ f" { 'tiled_us' :>8} { '±tl' :>4} "
47+ f" { 'v2_us' :>8} { '±v2' :>4} "
48+ f" { 'tl/fl%' :>7} { 'v2/fl%' :>7} "
49+ )
50+ print ("-" * 100 )
4551else :
46- print (f"{ 'shape' :<8} { 'K_dim' :>5} { 'N' :>5} { 'k' :>2} { 'M' :>2} { 'flat_us' :>8} { 'tiled_us' :>8} { 'diff%' :>7} " )
47- print ("-" * 60 )
52+ print (
53+ f"{ 'shape' :<8} { 'K_dim' :>5} { 'N' :>5} { 'k' :>2} { 'M' :>2} "
54+ f" { 'flat_us' :>8} { 'tiled_us' :>8} { 'v2_us' :>8} { 'tl/fl%' :>7} { 'v2/fl%' :>7} "
55+ )
56+ print ("-" * 75 )
4857
4958for name , K_dim , N in SHAPES :
5059 for k in K_VALUES :
6372 # Pre-allocate output buffers
6473 out_flat = torch .empty (M , N , dtype = torch .float16 , device = "cuda" )
6574 out_tiled = torch .empty (M , N , dtype = torch .float16 , device = "cuda" )
75+ out_v2 = torch .empty (M , N , dtype = torch .float16 , device = "cuda" )
76+
77+ # v2 workspace
78+ n_tiles = N // 128
79+ C_workspace = torch .zeros (M , N , dtype = torch .float32 , device = "cuda" )
80+ tile_counters = torch .zeros (n_tiles , dtype = torch .int32 , device = "cuda" )
6681
6782 if args .ncu :
68- # NCU mode: single call each, profiler captures kernel time
6983 torch .ops .bitsandbytes .kbit_scalar_gemv .out (
7084 A , packed_flat , absmax_flat , codebook , K_dim , N , k , out_flat
7185 )
7286 torch .ops .bitsandbytes .kbit_scalar_gemv_tiled_ (
7387 A , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out_tiled
7488 )
75- print (f"{ name :<8} { K_dim :>5} { N :>5} { k :>2} { M :>2} { 'ncu' :>8} { 'ncu' :>8} { 'ncu' :>7} " )
89+ torch .ops .bitsandbytes .kbit_scalar_gemv_v2_ (
90+ A , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out_v2 , C_workspace , tile_counters
91+ )
92+ print (
93+ f"{ name :<8} { K_dim :>5} { N :>5} { k :>2} { M :>2} "
94+ f" { 'ncu' :>8} { 'ncu' :>8} { 'ncu' :>8} { 'ncu' :>7} { 'ncu' :>7} "
95+ )
7696 continue
7797
7898 start = torch .cuda .Event (enable_timing = True )
@@ -88,11 +108,16 @@ def call_tiled():
88108 A , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out_tiled
89109 )
90110
111+ def call_v2 ():
112+ torch .ops .bitsandbytes .kbit_scalar_gemv_v2_ (
113+ A , packed_tiled , absmax_tiled , codebook , K_dim , N , k , out_v2 , C_workspace , tile_counters
114+ )
115+
91116 if args .graph :
92117 import statistics
93118
94119 # CUDA graph replay — measures kernel-only time
95- for fn in (call_flat , call_tiled ):
120+ for fn in (call_flat , call_tiled , call_v2 ):
96121 for _ in range (3 ):
97122 fn ()
98123 torch .cuda .synchronize ()
@@ -117,38 +142,44 @@ def bench_graph(fn, trials, iters):
117142
118143 flat_us , flat_std = bench_graph (call_flat , args .trials , args .iters )
119144 tiled_us , tiled_std = bench_graph (call_tiled , args .trials , args .iters )
145+ v2_us , v2_std = bench_graph (call_v2 , args .trials , args .iters )
120146 else :
121- # CUDA events timing (includes Python dispatch overhead)
122- for _ in range (args .warmup ):
123- call_flat ()
124- torch .cuda .synchronize ()
125- start .record ()
126- for _ in range (args .iters ):
127- call_flat ()
128- end .record ()
129- torch .cuda .synchronize ()
130- flat_us = start .elapsed_time (end ) * 1000 / args .iters
147+ def bench_events ( fn ):
148+ for _ in range (args .warmup ):
149+ fn ()
150+ torch .cuda .synchronize ()
151+ start .record ()
152+ for _ in range (args .iters ):
153+ fn ()
154+ end .record ()
155+ torch .cuda .synchronize ()
156+ return start .elapsed_time (end ) * 1000 / args .iters
131157
132- for _ in range (args .warmup ):
133- call_tiled ()
134- torch .cuda .synchronize ()
135- start .record ()
136- for _ in range (args .iters ):
137- call_tiled ()
138- end .record ()
139- torch .cuda .synchronize ()
140- tiled_us = start .elapsed_time (end ) * 1000 / args .iters
158+ flat_us = bench_events (call_flat )
159+ tiled_us = bench_events (call_tiled )
160+ v2_us = bench_events (call_v2 )
141161
142- diff_pct = (tiled_us - flat_us ) / flat_us * 100
162+ tl_pct = (tiled_us - flat_us ) / flat_us * 100
163+ v2_pct = (v2_us - flat_us ) / flat_us * 100
143164 if args .graph :
144165 print (
145166 f"{ name :<8} { K_dim :>5} { N :>5} { k :>2} { M :>2} "
146- f" { flat_us :>8.1f} { flat_std :>5.1f} σ { tiled_us :>8.1f} { tiled_std :>5.1f} σ { diff_pct :>+7.1f} %"
167+ f" { flat_us :>8.1f} { flat_std :>5.1f} σ"
168+ f" { tiled_us :>8.1f} { tiled_std :>3.1f} σ"
169+ f" { v2_us :>8.1f} { v2_std :>3.1f} σ"
170+ f" { tl_pct :>+7.1f} % { v2_pct :>+7.1f} %"
147171 )
148172 else :
149- print (f"{ name :<8} { K_dim :>5} { N :>5} { k :>2} { M :>2} { flat_us :>8.1f} { tiled_us :>8.1f} { diff_pct :>+7.1f} %" )
173+ print (
174+ f"{ name :<8} { K_dim :>5} { N :>5} { k :>2} { M :>2} "
175+ f" { flat_us :>8.1f} { tiled_us :>8.1f} { v2_us :>8.1f} { tl_pct :>+7.1f} % { v2_pct :>+7.1f} %"
176+ )
150177
151178 # Correctness check (once per shape/k)
152179 assert torch .equal (out_flat , out_tiled ) or torch .allclose (out_flat , out_tiled , rtol = 0.05 , atol = 0.1 ), (
153- f"MISMATCH { name } k={ k } : max diff = { (out_flat - out_tiled ).abs ().max ().item ()} "
180+ f"MISMATCH flat vs tiled { name } k={ k } : max diff = { (out_flat - out_tiled ).abs ().max ().item ()} "
181+ )
182+ # v2 uses split-K so small FP diffs are expected
183+ assert torch .allclose (out_flat .float (), out_v2 .float (), rtol = 0.1 , atol = 1.0 ), (
184+ f"MISMATCH flat vs v2 { name } k={ k } : max diff = { (out_flat .float () - out_v2 .float ()).abs ().max ().item ()} "
154185 )