1515
1616from bitsandbytes .utils import pack_dict_to_tensor , unpack_tensor_to_dict
1717
18- from .cextension import ROCM_WARP_SIZE_64 , lib
18+ from .cextension import lib
1919
2020name2qmap = {}
2121
@@ -869,8 +869,6 @@ def quantize_fp4(
869869 compress_statistics = False ,
870870 quant_storage = torch .uint8 ,
871871):
872- if blocksize is None :
873- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
874872 return quantize_4bit (A , absmax , out , blocksize , compress_statistics , "fp4" , quant_storage )
875873
876874
@@ -882,8 +880,6 @@ def quantize_nf4(
882880 compress_statistics = False ,
883881 quant_storage = torch .uint8 ,
884882):
885- if blocksize is None :
886- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
887883 return quantize_4bit (A , absmax , out , blocksize , compress_statistics , "nf4" , quant_storage )
888884
889885
@@ -905,7 +901,7 @@ def quantize_4bit(
905901 absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
906902 out (`torch.Tensor`, *optional*): A tensor to use to store the result.
907903 blocksize (`int`, *optional*):
908- The size of the blocks. Defaults to 128 on ROCm and 64 otherwise .
904+ The size of the blocks. Defaults to 64 .
909905 Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
910906 compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
911907 quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -921,7 +917,7 @@ def quantize_4bit(
921917 """
922918
923919 if blocksize is None :
924- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
920+ blocksize = 64
925921
926922 input_shape = A .shape
927923
@@ -975,8 +971,6 @@ def dequantize_fp4(
975971 out : Optional [torch .Tensor ] = None ,
976972 blocksize : Optional [int ] = None ,
977973) -> torch .Tensor :
978- if blocksize is None :
979- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
980974 return dequantize_4bit (A , quant_state , absmax , out , blocksize , "fp4" )
981975
982976
@@ -987,8 +981,6 @@ def dequantize_nf4(
987981 out : Optional [torch .Tensor ] = None ,
988982 blocksize : Optional [int ] = None ,
989983) -> torch .Tensor :
990- if blocksize is None :
991- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
992984 return dequantize_4bit (A , quant_state , absmax , out , blocksize , "nf4" )
993985
994986
@@ -1016,7 +1008,7 @@ def dequantize_4bit(
10161008 Required if `quant_state` is not provided and ignored otherwise.
10171009 out (`torch.Tensor`, *optional*): A tensor to use to store the result.
10181010 blocksize (`int`, *optional*):
1019- The size of the blocks. Defaults to 128 on ROCm and 64 otherwise .
1011+ The size of the blocks. Defaults to 64 .
10201012 Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
10211013 quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
10221014
@@ -1028,7 +1020,7 @@ def dequantize_4bit(
10281020 """
10291021
10301022 if blocksize is None :
1031- blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
1023+ blocksize = 64
10321024
10331025 if quant_state is None :
10341026 assert absmax is not None and out is not None
0 commit comments