# Test configurations

# Tile sizes (block_n, block_k) used for block-wise weight quantization.
BLOCK_SIZE = [[128, 128]]
# NOTE(review): the previous revision truncated these matrices to a two-element
# debug subset and left the real values in trailing comments
# (e.g. `M_list = [1, 7]#, 83, 512, 2048]`). Committing the reduced matrix
# silently shrinks test coverage, so the full lists are restored here.
M_list = [1, 7, 83, 512, 2048]
N_list = [128, 512, 1024, 4096, 7748, 13824]
K_list = [256, 4096, 5120, 3884, 13824]
# Quantized weight dtypes exercised by the parametrized cases.
_WEIGHT_DTYPES = [InfiniDtype.I8]

# Single fixed seed so failures reproduce deterministically.
SEEDS = 0
3739def to_iter (x ):
@@ -44,12 +46,13 @@ def to_iter(x):
4446 to_iter (K_list ),
4547 to_iter (N_list ),
4648 to_iter (BLOCK_SIZE ),
49+ to_iter (_WEIGHT_DTYPES ),
4750 )
4851)
4952
5053
# Data types used for testing
# Activation/output dtypes swept by test_operator; BF16 was added alongside
# the original F16 in this change.
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]


# When True, enables extra debug output in the debug() comparison helper
# (helper not visible in this chunk — TODO confirm).
DEBUG = False
@@ -108,164 +111,82 @@ def native_w8a16_block_int8_matmul(
108111 return C
109112
110113
111- def native_w8a16_block_fp8_matmul (
112- A ,
113- B ,
114- Bs ,
115- block_size ,
116- output_dtype : torch .float16 ,
117- ) -> torch .Tensor :
118- return native_w8a16_block_int8_matmul (A , B , Bs , block_size , output_dtype )
119-
120-
121- def test_w8a8_block_fp8_matmul (M , N , K , block_size , out_dtype , seed ):
122- torch .manual_seed (seed )
123- factor_for_scale = 1e-2
124- fp8_info = torch .finfo (torch .float8_e4m3fn )
125- fp8_max , fp8_min = fp8_info .max , fp8_info .min
126-
127- A_fp32 = (torch .rand (M , K , dtype = torch .float32 ) - 0.5 ) * 2 * fp8_max
128- #A_fp32 = A_fp32.fill_(1)
129- A_fp8 = A_fp32 .clamp (min = fp8_min , max = fp8_max ).to (torch .float8_e4m3fn )
130-
131- B_fp32 = (torch .rand (N , K , dtype = torch .float32 ) - 0.5 ) * 2 * fp8_max
132- #B_fp32 = B_fp32.fill_(1)
133- B_fp8 = B_fp32 .clamp (min = fp8_min , max = fp8_max ).to (torch .float8_e4m3fn )
134-
135- block_n , block_k = block_size [0 ], block_size [1 ]
136- n_tiles = (N + block_n - 1 ) // block_n
137- k_tiles = (K + block_k - 1 ) // block_k
138-
139- As = torch .rand (M , k_tiles , dtype = torch .float32 ) * factor_for_scale
140- #As = As.fill_(1)
141- Bs = torch .rand (n_tiles , k_tiles , dtype = torch .float32 ) * factor_for_scale
142- #Bs = Bs.fill_(1.5)
143- #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
144- # out_dtype)
145- ref_out = native_w8a16_block_fp8_matmul (A_fp32 .to (torch .bfloat16 ), B_fp8 , Bs , block_size , out_dtype )
146- #out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
147-
148- B_fp8_T = B_fp8 .t ()
149- #print('B_fp8_T', B_fp8_T.size(), B_fp8_T)
150-
151- Bs_T = Bs
152- quant_type = 3
153- bit = 8
154- return ref_out , A_fp32 .to (torch .bfloat16 ), B_fp8_T , Bs_T , Bs_T , quant_type , bit
155-
156-
157- def test_w8a8_block_int8_matmul (M , N , K , block_size , out_dtype , seed ):
158- torch .manual_seed (seed )
159- factor_for_scale = 1e-2
160- int8_info = torch .iinfo (torch .int8 )
161- int8_max , int8_min = int8_info .max , int8_info .min
162-
163- A_fpb16 = torch .rand (M , K , dtype = torch .float32 ) / 10
164-
165-
166- #A_fp32 = A_fp32.fill_(1)
167- #A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
168-
169- B_fp32 = (torch .rand (N , K , dtype = torch .float32 ) - 0.5 ) * 2 * int8_max
170- #B_fp32 = B_fp32.fill_(1)
171- B_int8 = B_fp32 .clamp (min = int8_min , max = int8_max ).to (torch .int8 )
172-
173- block_n , block_k = block_size [0 ], block_size [1 ]
174- n_tiles = (N + block_n - 1 ) // block_n
175- k_tiles = (K + block_k - 1 ) // block_k
176-
177- A_fpb16 = A_fpb16 .to (torch .float16 )
178-
179- #As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
180- #As = As.fill_(1)
181- Bs = torch .rand (n_tiles , k_tiles , dtype = torch .float32 ) * factor_for_scale
182- #Bs = Bs.fill_(1.5)
183- #ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
184-
185- ref_out = native_w8a16_block_fp8_matmul (A_fpb16 , B_int8 , Bs , block_size , out_dtype )
186- #a_q, a_s = native_per_token_group_quant_int8(A_fpb16, block_k)
187- #ref_out = native_w8a8_block_int8_matmul(a_q, B_int8, a_s, Bs, block_size, output_dtype=A_fpb16.dtype)
188- ##out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
189- #print('Bs', Bs.size(), Bs.dtype)
190- quant_type = 3
191- bit = 8
192- return ref_out , A_fpb16 , B_int8 , Bs , Bs , quant_type , bit
193-
194-
195- def test_int8 (
114+ def test (
196115 handle ,
197116 device ,
198117 M ,
199118 K ,
200119 N ,
201120 block_size ,
121+ weight_dtype = InfiniDtype .I8 ,
202122 dtype = InfiniDtype .BF16 ,
203123 sync = None ,
204124):
205125
206126 print (
207- f"Testing int8 Gptq Qyblas Gemm on { InfiniDeviceNames [device ]} with M-K-N:{ M , K , N } , block_size:{ block_size } , dtype:{ InfiniDtypeNames [dtype ]} "
127+ f"Testing int8 Gptq Qyblas Gemm on { InfiniDeviceNames [device ]} with M-K-N:{ M , K , N } , block_size:{ block_size } , weight dtype: { InfiniDtypeNames [ weight_dtype ] } , dtype:{ InfiniDtypeNames [dtype ]} "
208128 )
209- out_dtype = to_torch_dtype (dtype )
210- ans , a , b_orig , b_scales , b_zeros , quant_type , bit = test_w8a8_block_int8_matmul (M , N , K , block_size , out_dtype , SEEDS )
211- b = b_orig .t ()
212-
129+ quant_type = 3
130+ bit = 8
131+
132+ int8_info = torch .iinfo (torch .int8 )
133+ int8_max , int8_min = int8_info .max , int8_info .min
134+
135+ block_n , block_k = block_size [0 ], block_size [1 ]
136+ n_tiles = (N + block_n - 1 ) // block_n
137+ k_tiles = (K + block_k - 1 ) // block_k
138+
213139 A = TestTensor (
214- a .shape ,
215- a .stride (),
216- InfiniDtype .F16 ,
217- device ,
218- mode = "manual" ,
219- set_tensor = a ,
220- )
221- B_orig = TestTensor (
222- b_orig .shape ,
223- b_orig .stride (),
224- InfiniDtype .I8 ,
225- device ,
226- mode = "manual" ,
227- set_tensor = b_orig ,
228- )
229- B = TestTensor (
230- b .shape ,
231- b .stride (),
232- InfiniDtype .I8 ,
140+ (M , K ),
141+ None ,
142+ dtype ,
233143 device ,
234- mode = "manual" ,
235- set_tensor = b ,
236144 )
145+ if weight_dtype == InfiniDtype .I8 :
146+ B_orig = TestTensor (
147+ (N , K ),
148+ None ,
149+ weight_dtype ,
150+ device ,
151+ randint_low = int8_min ,
152+ randint_high = int8_max ,
153+ )
154+ B_torch = B_orig .torch_tensor ().t ()
155+ B = TestTensor (
156+ (K , N ),
157+ B_torch .stride (),
158+ weight_dtype ,
159+ device ,
160+ mode = "manual" ,
161+ set_tensor = B_torch ,
162+ )
163+
237164 b_scales = TestTensor (
238- b_scales . shape ,
239- b_scales . stride () ,
165+ ( n_tiles , k_tiles ) ,
166+ None ,
240167 InfiniDtype .F32 ,
241168 device ,
242- mode = "manual" ,
243- set_tensor = b_scales ,
244169 )
170+
245171 b_zeros = TestTensor (
246- b_zeros . shape ,
247- b_zeros . stride () ,
172+ ( n_tiles , k_tiles ) ,
173+ None ,
248174 InfiniDtype .F32 ,
249175 device ,
250- mode = "manual" ,
251- set_tensor = b_zeros ,
176+ mode = "zeros" ,
252177 )
178+
253179 out = TestTensor (
254- ans . shape ,
180+ ( M , N ) ,
255181 None ,
256182 dtype ,
257183 device ,
184+ mode = "zeros" ,
258185 )
259-
260- print ("a: " , A .torch_tensor ().shape , A .torch_tensor ().stride (), A .torch_tensor ().dtype )
261- print ("b: " , B .torch_tensor ().shape , B .torch_tensor ().stride (), B .torch_tensor ().dtype )
262- print ("scales: " , b_scales .torch_tensor ().shape , b_scales .torch_tensor ().dtype )
263- print ("zeros: " , b_zeros .torch_tensor ().shape , b_zeros .torch_tensor ().dtype )
264- print ("out: " , out .torch_tensor ().shape , out .torch_tensor ().dtype )
186+
265187 if sync is not None :
266188 sync ()
267189
268-
269190 descriptor = infiniopOperatorDescriptor_t ()
270191 check_error (
271192 LIBINFINIOP .infiniopCreateGptqQyblasGemmDescriptor (
@@ -278,7 +199,6 @@ def test_int8(
278199 b_zeros .descriptor ,
279200 )
280201 )
281-
282202 # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
283203
284204 for tensor in [out , A , B , b_scales , b_zeros ]:
@@ -314,31 +234,20 @@ def lib_gptq_qyblas_gemm():
314234 if sync is not None :
315235 sync ()
316236
317- tmpa = out . torch_tensor (). to ( torch . float32 ). detach (). to ( 'cpu' ). numpy (). flatten ( )
318- tmpb = ans . to ( torch . float32 ). to ( 'cpu' ). detach (). numpy (). flatten ( )
237+ out_dtype = to_torch_dtype ( dtype )
238+ ans = native_w8a16_block_int8_matmul ( A . torch_tensor (), B_orig . torch_tensor (), b_scales . torch_tensor (), block_size , out_dtype )
319239
320- atol = max (abs (tmpa - tmpb ))
321-
322- rtol = atol / (max (abs (tmpb )) + 1e-8 )
323-
324-
325- print ("absolute error:%.4e" % (atol ))
326- print ("relative error:%.4e" % (rtol ))
327- print (out .torch_tensor ().device , ans .device )
328- # print(out.torch_tensor())
329- # print(ans)
330- ans = ans .to (out .torch_tensor ().device )
331240 rel_diff = (torch .mean (
332- torch .abs (out .torch_tensor ().to (torch .float32 ) - ans .to (torch .float32 ))) /
241+ torch .abs (out .actual_tensor ().to (torch .float32 ) - ans .to (torch .float32 ))) /
333242 torch .mean (torch .abs (ans .to (torch .float32 ))))
334- print ( rel_diff )
243+
335244 assert rel_diff < 0.05
336245
337246
338247 # Profiling workflow
339248 if PROFILE :
340249 # fmt: off
341- profile_operation ("PyTorch" , lambda : native_w8a16_block_fp8_matmul (A .torch_tensor (), B_orig .torch_tensor (), b_scales .torch_tensor (), block_size , out_dtype ), device , NUM_PRERUN , NUM_ITERATIONS )
250+ profile_operation ("PyTorch" , lambda : native_w8a16_block_int8_matmul (A .torch_tensor (), B_orig .torch_tensor (), b_scales .torch_tensor (), block_size , out_dtype ), device , NUM_PRERUN , NUM_ITERATIONS )
342251 profile_operation (" lib" , lambda : lib_gptq_qyblas_gemm (), device , NUM_PRERUN , NUM_ITERATIONS )
343252 # fmt: on
344253
@@ -355,6 +264,6 @@ def lib_gptq_qyblas_gemm():
355264 NUM_ITERATIONS = args .num_iterations
356265
357266 for device in get_test_devices (args ):
358- test_operator (device , test_int8 , _TEST_CASES , _TENSOR_DTYPES )
267+ test_operator (device , test , _TEST_CASES , _TENSOR_DTYPES )
359268
360269 print ("\033 [92mTest passed!\033 [0m" )