Skip to content

Commit d607127

Browse files
committed
add support for 64 block size on AMD GPUs that use a 32-thread warp size
1 parent 39dd847 commit d607127

File tree

8 files changed

+63
-30
lines changed

8 files changed

+63
-30
lines changed

bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import HIP_ENVIRONMENT, lib
11+
from ...cextension import ROCM_WARP_SIZE_64, lib
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212212
torch._check_is_size(blocksize)
213213

214-
if HIP_ENVIRONMENT:
214+
if ROCM_WARP_SIZE_64:
215215
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216216
else:
217217
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
269269
def _dequantize_blockwise_impl(
270270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
271271
) -> None:
272-
if HIP_ENVIRONMENT:
272+
if ROCM_WARP_SIZE_64:
273273
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274274
else:
275275
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
303303
def _(
304304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
305305
) -> tuple[torch.Tensor, torch.Tensor]:
306-
if HIP_ENVIRONMENT:
306+
if ROCM_WARP_SIZE_64:
307307
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308308
else:
309309
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
385385
dtype: torch.dtype,
386386
out: torch.Tensor,
387387
) -> None:
388-
if HIP_ENVIRONMENT:
388+
if ROCM_WARP_SIZE_64:
389389
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390390
else:
391391
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])

bitsandbytes/cextension.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
import torch
1010

1111
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
12-
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
12+
from bitsandbytes.cuda_specs import (
13+
CUDASpecs,
14+
get_cuda_specs,
15+
get_cuda_version_tuple,
16+
get_rocm_gpu_arch,
17+
get_rocm_warpsize,
18+
)
1319

1420
logger = logging.getLogger(__name__)
1521

@@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary:
298304

299305

300306
ROCM_GPU_ARCH = get_rocm_gpu_arch()
307+
ROCM_WARP_SIZE_64 = get_rocm_warpsize() == 64
301308

302309
try:
303310
# to support Intel CPU/GPU (XPU) backend

bitsandbytes/cuda_specs.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str:
100100
""",
101101
)
102102
return "unknown"
103+
104+
105+
def get_rocm_warpsize() -> int:
106+
"""Get ROCm warp size."""
107+
logger = logging.getLogger(__name__)
108+
try:
109+
if torch.version.hip:
110+
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
111+
match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
112+
if match:
113+
return int(match.group(1))
114+
else:
115+
# default to 64 to be safe
116+
return 64
117+
else:
118+
# nvidia cards always use 32 warp size
119+
return 32
120+
except Exception as e:
121+
logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)")
122+
if torch.cuda.is_available():
123+
logger.warning(
124+
"""
125+
ROCm warp size detection failed despite ROCm being available.
126+
""",
127+
)
128+
return 64

bitsandbytes/functional.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
1717

18-
from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
18+
from .cextension import ROCM_WARP_SIZE_64, ipex_cpu, ipex_xpu, lib
1919

2020
name2qmap = {}
2121

@@ -804,7 +804,7 @@ def quantize_fp4(
804804
quant_storage=torch.uint8,
805805
):
806806
if blocksize is None:
807-
blocksize = 64 if not HIP_ENVIRONMENT else 128
807+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
808808
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
809809

810810

@@ -817,7 +817,7 @@ def quantize_nf4(
817817
quant_storage=torch.uint8,
818818
):
819819
if blocksize is None:
820-
blocksize = 64 if not HIP_ENVIRONMENT else 128
820+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
821821
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
822822

823823

@@ -855,7 +855,7 @@ def quantize_4bit(
855855
"""
856856

857857
if blocksize is None:
858-
blocksize = 64 if not HIP_ENVIRONMENT else 128
858+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
859859

860860
input_shape = A.shape
861861

@@ -910,7 +910,7 @@ def dequantize_fp4(
910910
blocksize: Optional[int] = None,
911911
) -> torch.Tensor:
912912
if blocksize is None:
913-
blocksize = 64 if not HIP_ENVIRONMENT else 128
913+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
914914
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
915915

916916

@@ -922,7 +922,7 @@ def dequantize_nf4(
922922
blocksize: Optional[int] = None,
923923
) -> torch.Tensor:
924924
if blocksize is None:
925-
blocksize = 64 if not HIP_ENVIRONMENT else 128
925+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
926926
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
927927

928928

@@ -962,7 +962,7 @@ def dequantize_4bit(
962962
"""
963963

964964
if blocksize is None:
965-
blocksize = 64 if not HIP_ENVIRONMENT else 128
965+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
966966

967967
if quant_state is None:
968968
assert absmax is not None and out is not None

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch.nn.functional as F
1212

1313
import bitsandbytes as bnb
14-
from bitsandbytes.cextension import HIP_ENVIRONMENT
14+
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1515
from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
1616
from bitsandbytes.optim import GlobalOptimManager
1717
from bitsandbytes.utils import (
@@ -225,7 +225,7 @@ def __new__(
225225
data = torch.empty(0)
226226

227227
if blocksize is None:
228-
blocksize = 64 if not HIP_ENVIRONMENT else 128
228+
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
229229

230230
self = torch.Tensor._make_subclass(cls, data, requires_grad)
231231
self.blocksize = blocksize

tests/test_functional.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import bitsandbytes as bnb
1111
from bitsandbytes import functional as F
12-
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
12+
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH, ROCM_WARP_SIZE_64
1313
from tests.helpers import (
1414
BOOLEAN_TUPLES,
1515
TRUE_FALSE,
@@ -95,7 +95,7 @@ class Test8BitBlockwiseQuantizeFunctional:
9595
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
9696
@pytest.mark.parametrize(
9797
"blocksize",
98-
[4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128],
98+
[4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128],
9999
)
100100
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
101101
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
@@ -1107,7 +1107,7 @@ class TestQuantize4BitFunctional:
11071107
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
11081108
@pytest.mark.parametrize(
11091109
"blocksize",
1110-
[64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096],
1110+
[64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
11111111
)
11121112
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11131113
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1174,7 +1174,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11741174

11751175
@pytest.mark.parametrize("device", get_available_devices())
11761176
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1177-
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
1177+
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
11781178
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
11791179
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
11801180
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1241,7 +1241,7 @@ def test_bench_4bit_dequant(self, quant_type):
12411241
# print((time.time()-t0)/iters*1e6)
12421242

12431243
@pytest.mark.skipif(
1244-
HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
1244+
ROCM_WARP_SIZE_64, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
12451245
)
12461246
@pytest.mark.parametrize("device", get_available_devices())
12471247
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")

tests/test_linear4bit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99

1010
import bitsandbytes as bnb
11-
from bitsandbytes.cextension import HIP_ENVIRONMENT
11+
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1212
from tests.helpers import (
1313
TRUE_FALSE,
1414
describe_dtype,
@@ -192,7 +192,7 @@ def test_linear_serialization(
192192

193193
@pytest.mark.parametrize("device", get_available_devices())
194194
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
195-
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
195+
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
196196
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
197197
def test_copy_param(device, quant_type, blocksize, compress_statistics):
198198
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -249,7 +249,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):
249249

250250
@pytest.mark.parametrize("device", get_available_devices())
251251
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
252-
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
252+
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
253253
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
254254
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
255255
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -278,7 +278,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
278278

279279
@pytest.mark.parametrize("device", get_available_devices())
280280
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
281-
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
281+
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
282282
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
283283
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
284284
if device == "hpu" and not is_supported_on_hpu(quant_type):

tests/test_ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66
import bitsandbytes
7-
from bitsandbytes.cextension import HIP_ENVIRONMENT
7+
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
88
from bitsandbytes.functional import ipex_xpu
99
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu
1010

@@ -103,7 +103,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias):
103103
class TestInt8BlockwiseQuantOps:
104104
@pytest.mark.parametrize("device", get_available_devices())
105105
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
106-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
106+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
107107
def test_quantize_blockwise(self, device, dtype, blocksize):
108108
if device == "cpu":
109109
if dtype != torch.float32:
@@ -127,7 +127,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
127127

128128
@pytest.mark.parametrize("device", get_available_devices())
129129
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
130-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
130+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
131131
def test_dequantize_blockwise(self, device, dtype, blocksize):
132132
if device == "cpu" and dtype != torch.float32:
133133
pytest.skip("CPU implementation is only available for float32")
@@ -157,7 +157,7 @@ class Test4bitBlockwiseQuantOps:
157157
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
158158
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
159159
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
160-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
160+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
161161
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
162162
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
163163
pytest.skip("This configuration is not supported on HPU.")
@@ -181,7 +181,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
181181
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
182182
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
183183
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
184-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
184+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
185185
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
186186
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
187187
pytest.skip("This configuration is not supported on HPU.")
@@ -215,7 +215,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
215215
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
216216
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
217217
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
218-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
218+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
219219
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
220220
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
221221
pytest.skip("This configuration is not supported on HPU.")

0 commit comments

Comments
 (0)