44import torch
55
66import bitsandbytes
7- from bitsandbytes .cextension import HIP_ENVIRONMENT
7+ from bitsandbytes .cextension import ROCM_WARP_SIZE_64
88from bitsandbytes .functional import ipex_xpu
99from tests .helpers import TRUE_FALSE , get_available_devices , id_formatter , is_supported_on_hpu
1010
@@ -103,7 +103,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias):
103103class TestInt8BlockwiseQuantOps :
104104 @pytest .mark .parametrize ("device" , get_available_devices ())
105105 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
106- @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not HIP_ENVIRONMENT else [128 , 256 , 512 ])
106+ @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not ROCM_WARP_SIZE_64 else [128 , 256 , 512 ])
107107 def test_quantize_blockwise (self , device , dtype , blocksize ):
108108 if device == "cpu" :
109109 if dtype != torch .float32 :
@@ -127,7 +127,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
127127
128128 @pytest .mark .parametrize ("device" , get_available_devices ())
129129 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
130- @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not HIP_ENVIRONMENT else [128 , 256 , 512 ])
130+ @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not ROCM_WARP_SIZE_64 else [128 , 256 , 512 ])
131131 def test_dequantize_blockwise (self , device , dtype , blocksize ):
132132 if device == "cpu" and dtype != torch .float32 :
133133 pytest .skip ("CPU implementation is only available for float32" )
@@ -157,7 +157,7 @@ class Test4bitBlockwiseQuantOps:
157157 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
158158 @pytest .mark .parametrize ("storage_dtype" , [torch .uint8 , torch .bfloat16 ], ids = id_formatter ("storage_dtype" ))
159159 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
160- @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not HIP_ENVIRONMENT else [128 , 256 , 512 ])
160+ @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not ROCM_WARP_SIZE_64 else [128 , 256 , 512 ])
161161 def test_quantize_4bit (self , device , dtype , storage_dtype , quant_type , blocksize ):
162162 if device == "hpu" and not is_supported_on_hpu (quant_type , dtype , storage_dtype ):
163163 pytest .skip ("This configuration is not supported on HPU." )
@@ -181,7 +181,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
181181 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
182182 @pytest .mark .parametrize ("storage_dtype" , [torch .uint8 , torch .bfloat16 ], ids = id_formatter ("storage_dtype" ))
183183 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
184- @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not HIP_ENVIRONMENT else [128 , 256 , 512 ])
184+ @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not ROCM_WARP_SIZE_64 else [128 , 256 , 512 ])
185185 def test_dequantize_4bit (self , device , dtype , storage_dtype , quant_type , blocksize ):
186186 if device == "hpu" and not is_supported_on_hpu (quant_type , dtype , storage_dtype ):
187187 pytest .skip ("This configuration is not supported on HPU." )
@@ -215,7 +215,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
215215 @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ], ids = id_formatter ("dtype" ))
216216 @pytest .mark .parametrize ("storage_dtype" , [torch .uint8 , torch .bfloat16 ], ids = id_formatter ("storage_dtype" ))
217217 @pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
218- @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not HIP_ENVIRONMENT else [128 , 256 , 512 ])
218+ @pytest .mark .parametrize ("blocksize" , [64 , 128 , 256 , 512 ] if not ROCM_WARP_SIZE_64 else [128 , 256 , 512 ])
219219 def test_gemv_4bit (self , device , dtype , storage_dtype , quant_type , blocksize ):
220220 if device == "hpu" and not is_supported_on_hpu (quant_type , dtype , storage_dtype ):
221221 pytest .skip ("This configuration is not supported on HPU." )
0 commit comments