@@ -51,9 +51,14 @@ def to_iter(x):
5151)
5252
5353
# W4 (4-bit) GEMM test cases: (M, K, N, [block_n, block_k], weight storage dtype).
_TEST_CASES_W4 = [
    (32768, 3584, 4608, [128, 128], InfiniDtype.U8),
]


# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]

# Activation/output dtypes exercised by the w4 path.
_TENSOR_DTYPES_W4 = [InfiniDtype.F16]


DEBUG = False
PROFILE = False
@@ -129,9 +134,6 @@ def test(
129134 quant_type = 3
130135 bit = 8
131136
132- int8_info = torch .iinfo (torch .int8 )
133- int8_max , int8_min = int8_info .max , int8_info .min
134-
135137 block_n , block_k = block_size [0 ], block_size [1 ]
136138 n_tiles = (N + block_n - 1 ) // block_n
137139 k_tiles = (K + block_k - 1 ) // block_k
@@ -143,23 +145,28 @@ def test(
143145 device ,
144146 )
145147 if weight_dtype == InfiniDtype .I8 :
146- B_orig = TestTensor (
147- (N , K ),
148- None ,
149- weight_dtype ,
150- device ,
151- randint_low = int8_min ,
152- randint_high = int8_max ,
153- )
154- B_torch = B_orig .torch_tensor ().t ()
155- B = TestTensor (
156- (K , N ),
157- B_torch .stride (),
158- weight_dtype ,
159- device ,
160- mode = "manual" ,
161- set_tensor = B_torch ,
162- )
148+ _info = torch .iinfo (torch .int8 )
149+ elif weight_dtype == InfiniDtype .U8 :
150+ _info = torch .iinfo (torch .uint8 )
151+ elif weight_dtype == InfiniDtype .F8 :
152+ _info = torch .iinfo (float8_e4m3fn )
153+ B_orig = TestTensor (
154+ (N , K ),
155+ None ,
156+ weight_dtype ,
157+ device ,
158+ randint_low = _info .min ,
159+ randint_high = _info .max ,
160+ )
161+ B_torch = B_orig .torch_tensor ().t ()
162+ B = TestTensor (
163+ (K , N ),
164+ B_torch .stride (),
165+ weight_dtype ,
166+ device ,
167+ mode = "manual" ,
168+ set_tensor = B_torch ,
169+ )
163170
164171 b_scales = TestTensor (
165172 (n_tiles , k_tiles ),
@@ -254,6 +261,165 @@ def lib_gptq_qyblas_gemm():
254261 check_error (LIBINFINIOP .infiniopDestroyGptqQyblasGemmDescriptor (descriptor ))
255262
256263
def test_w4(
    handle,
    device,
    M,
    K,
    N,
    block_size,
    weight_dtype=InfiniDtype.I8,
    dtype=InfiniDtype.BF16,
    sync=None,
):
    """Run the 4-bit (w4) GPTQ qyblas GEMM kernel and compare it to a torch reference.

    Args:
        handle: infiniop library handle.
        device: device id (indexes InfiniDeviceNames).
        M, K, N: GEMM dimensions. K must be even, since the 4-bit weight is
            stored packed two values per byte as a (K // 2, N) tensor.
        block_size: [block_n, block_k] quantization block shape; only block_k
            is used on this path.
        weight_dtype: storage dtype of the packed weight bytes.
        dtype: activation / scale / output dtype.
        sync: optional callable that synchronizes the device.

    Raises:
        ValueError: if weight_dtype is not one of I8 / U8 / F8.
        AssertionError: if the kernel output deviates from the reference by
            more than 5% mean relative error.
    """
    print(
        f"Testing w4 Gptq Qyblas Gemm on {InfiniDeviceNames[device]} with M-K-N:{M, K, N}, block_size:{block_size}, weight dtype:{InfiniDtypeNames[weight_dtype]}, dtype:{InfiniDtypeNames[dtype]}"
    )
    quant_type = 0  # selects the w4 quantization path in the kernel
    bit = 4

    assert K % 2 == 0, "w4 weights pack two nibbles per byte; K must be even"

    # Only the K dimension is blocked on the w4 path (scales/zeros are (k_tiles, N)).
    block_k = block_size[1]
    k_tiles = (K + block_k - 1) // block_k

    A = TestTensor(
        (M, K),
        None,
        dtype,
        device,
    )

    # Value range for the random packed-weight bytes.
    if weight_dtype == InfiniDtype.I8:
        _info = torch.iinfo(torch.int8)
    elif weight_dtype == InfiniDtype.U8:
        _info = torch.iinfo(torch.uint8)
    elif weight_dtype == InfiniDtype.F8:
        # BUGFIX: float8 is a floating type — torch.iinfo raises on it; finfo
        # provides the representable range.
        # NOTE(review): finfo bounds are floats — confirm TestTensor's
        # randint_low/high accept them.
        _info = torch.finfo(torch.float8_e4m3fn)
    else:
        raise ValueError(f"unsupported weight dtype: {InfiniDtypeNames[weight_dtype]}")

    # Packed 4-bit weights: two values per byte -> (K // 2, N) storage bytes.
    B = TestTensor(
        (K // 2, N),
        None,
        weight_dtype,
        device,
        randint_low=_info.min,
        randint_high=_info.max,
    )

    # Per-block dequantization scales, one row per K-block.
    b_scales = TestTensor(
        (k_tiles, N),
        None,
        dtype,
        device,
    )

    # Zero points, initialized to zero (symmetric quantization in this test).
    b_zeros = TestTensor(
        (k_tiles, N),
        None,
        dtype,
        device,
        mode="zeros",
    )

    out = TestTensor(
        (M, N),
        None,
        dtype,
        device,
        mode="zeros",
    )

    if DEBUG:
        # Shape/stride dump for debugging kernel tensor layouts.
        for label, t in (("A", A), ("B", B), ("scales", b_scales), ("zeros", b_zeros), ("out", out)):
            tt = t.torch_tensor()
            print(label, tt.shape, tt.dtype, tt.stride())

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateGptqQyblasGemmDescriptor(
            handle,
            ctypes.byref(descriptor),
            out.descriptor,
            A.descriptor,
            B.descriptor,
            b_scales.descriptor,
            b_zeros.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from
    # being directly used by the kernel.
    for tensor in [out, A, B, b_scales, b_zeros]:
        tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetGptqQyblasGemmWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, A.device)

    def lib_gptq_qyblas_gemm():
        # One kernel invocation with the bound descriptor and buffers.
        check_error(
            LIBINFINIOP.infiniopGptqQyblasGemm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                out.data(),
                A.data(),
                B.data(),
                b_scales.data(),
                b_zeros.data(),
                quant_type,
                bit,
                None,
            )
        )

    lib_gptq_qyblas_gemm()

    if sync is not None:
        sync()

    out_dtype = to_torch_dtype(dtype)
    # BUGFIX: the reference previously used `B_orig`, which is never defined on
    # this path (its construction was commented out) and raised NameError; use
    # the packed weight tensor directly.
    # TODO(review): this feeds the raw packed bytes to the w8 reference helper;
    # a proper w4 reference should unpack/dequantize the nibbles first — confirm
    # against the kernel's packing format.
    ans = native_w8a16_block_int8_matmul(
        A.torch_tensor(), B.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype
    )

    # Mean relative error between kernel output and reference, in float32.
    rel_diff = torch.mean(
        torch.abs(out.actual_tensor().to(torch.float32) - ans.to(torch.float32))
    ) / torch.mean(torch.abs(ans.to(torch.float32)))

    assert rel_diff < 0.05

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: native_w8a16_block_int8_matmul(A.torch_tensor(), B.torch_tensor(), b_scales.torch_tensor(), block_size, out_dtype), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_gptq_qyblas_gemm(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyGptqQyblasGemmDescriptor(descriptor))
422+
257423if __name__ == "__main__" :
258424 args = get_args ()
259425
@@ -263,7 +429,9 @@ def lib_gptq_qyblas_gemm():
263429 NUM_PRERUN = args .num_prerun
264430 NUM_ITERATIONS = args .num_iterations
265431
432+ # for device in get_test_devices(args):
433+ # test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
266434 for device in get_test_devices (args ):
267- test_operator (device , test , _TEST_CASES , _TENSOR_DTYPES )
435+ test_operator (device , test_w4 , _TEST_CASES_W4 , _TENSOR_DTYPES_W4 )
268436
269437 print ("\033 [92mTest passed!\033 [0m" )