1818sys .path .insert (0 , os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
1919
2020import torch
21+
2122from bitsandbytes .functional import create_normal_float_codebook
2223
2324WARMUP = 20
@@ -85,9 +86,7 @@ def prepare_dense_data(device):
8586 codebook = create_normal_float_codebook (k , device = device )
8687 W = torch .randn (K_dim * N , device = device , dtype = torch .float32 )
8788 packed_flat , absmax_flat = torch .ops .bitsandbytes .quantize_kbit (W , codebook , k )
88- packed_tiled , absmax_tiled = torch .ops .bitsandbytes .repack_kbit (
89- packed_flat , absmax_flat , K_dim , N , k
90- )
89+ packed_tiled , absmax_tiled = torch .ops .bitsandbytes .repack_kbit (packed_flat , absmax_flat , K_dim , N , k )
9190 data [(name , k )] = (K_dim , N , packed_flat , absmax_flat , packed_tiled , absmax_tiled , codebook )
9291 return data
9392
@@ -134,8 +133,17 @@ def bench_mma(data, m_vals, device):
134133 tile_counters = torch .zeros (m_tiles * n_tiles , dtype = torch .int32 , device = device )
135134
136135 fn = lambda : torch .ops .bitsandbytes .kbit_gemm_prod_ (
137- A , packed_tiled , absmax_tiled , codebook ,
138- K_dim , N , k , 1 , out , C_workspace , tile_counters ,
136+ A ,
137+ packed_tiled ,
138+ absmax_tiled ,
139+ codebook ,
140+ K_dim ,
141+ N ,
142+ k ,
143+ 1 ,
144+ out ,
145+ C_workspace ,
146+ tile_counters ,
139147 )
140148 avg_us = bench_kernel (fn )
141149 print (f"{ name :<8} { k :>2} { M :>2} { avg_us :>10.2f} " )
@@ -160,8 +168,14 @@ def bench_scalar(data, m_vals, device):
160168 out = torch .empty (M , N , dtype = torch .float16 , device = device )
161169
162170 fn = lambda : torch .ops .bitsandbytes .kbit_scalar_gemv_tiled_ (
163- A , packed_tiled , absmax_tiled , codebook ,
164- K_dim , N , k , out ,
171+ A ,
172+ packed_tiled ,
173+ absmax_tiled ,
174+ codebook ,
175+ K_dim ,
176+ N ,
177+ k ,
178+ out ,
165179 )
166180 avg_us = bench_kernel (fn )
167181 print (f"{ name :<8} { k :>2} { M :>2} { avg_us :>10.2f} " )
@@ -184,8 +198,16 @@ def bench_grouped(moe_data, m_vals, device):
184198
185199 # Grouped GEMM doesn't have an _ variant yet — use the allocating version
186200 fn = lambda : torch .ops .bitsandbytes .kbit_grouped_gemm (
187- A_concat , B_packed_all , B_absmax_all , codebook ,
188- expert_offsets , K_dim , N , k , NUM_EXPERTS , M ,
201+ A_concat ,
202+ B_packed_all ,
203+ B_absmax_all ,
204+ codebook ,
205+ expert_offsets ,
206+ K_dim ,
207+ N ,
208+ k ,
209+ NUM_EXPERTS ,
210+ M ,
189211 )
190212 avg_us = bench_kernel (fn )
191213 print (f"{ name :<8} { k :>2} { M :>2} { avg_us :>10.2f} " )
0 commit comments