Skip to content

Commit c2ae381

Browse files
[ROCm] Make blocksize=64 default for 4bit (#1873)
* [ROCm] Make blocksize=64 default for 4bit
* Update test
1 parent 943e42d commit c2ae381

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

bitsandbytes/functional.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
1717

18-
from .cextension import ROCM_WARP_SIZE_64, lib
18+
from .cextension import lib
1919

2020
name2qmap = {}
2121

@@ -869,8 +869,6 @@ def quantize_fp4(
869869
compress_statistics=False,
870870
quant_storage=torch.uint8,
871871
):
872-
if blocksize is None:
873-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
874872
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
875873

876874

@@ -882,8 +880,6 @@ def quantize_nf4(
882880
compress_statistics=False,
883881
quant_storage=torch.uint8,
884882
):
885-
if blocksize is None:
886-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
887883
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
888884

889885

@@ -905,7 +901,7 @@ def quantize_4bit(
905901
absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
906902
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
907903
blocksize (`int`, *optional*):
908-
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
904+
The size of the blocks. Defaults to 64.
909905
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
910906
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
911907
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -921,7 +917,7 @@ def quantize_4bit(
921917
"""
922918

923919
if blocksize is None:
924-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
920+
blocksize = 64
925921

926922
input_shape = A.shape
927923

@@ -975,8 +971,6 @@ def dequantize_fp4(
975971
out: Optional[torch.Tensor] = None,
976972
blocksize: Optional[int] = None,
977973
) -> torch.Tensor:
978-
if blocksize is None:
979-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
980974
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
981975

982976

@@ -987,8 +981,6 @@ def dequantize_nf4(
987981
out: Optional[torch.Tensor] = None,
988982
blocksize: Optional[int] = None,
989983
) -> torch.Tensor:
990-
if blocksize is None:
991-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
992984
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
993985

994986

@@ -1016,7 +1008,7 @@ def dequantize_4bit(
10161008
Required if `quant_state` is not provided and ignored otherwise.
10171009
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
10181010
blocksize (`int`, *optional*):
1019-
The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
1011+
The size of the blocks. Defaults to 64.
10201012
Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
10211013
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
10221014
@@ -1028,7 +1020,7 @@ def dequantize_4bit(
10281020
"""
10291021

10301022
if blocksize is None:
1031-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
1023+
blocksize = 64
10321024

10331025
if quant_state is None:
10341026
assert absmax is not None and out is not None

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch.nn.functional as F
1212

1313
import bitsandbytes as bnb
14-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1514
from bitsandbytes.functional import (
1615
QuantState,
1716
_convert_weight_packed_for_cpu,
@@ -226,7 +225,7 @@ def __new__(
226225
data = torch.empty(0)
227226

228227
if blocksize is None:
229-
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
228+
blocksize = 64
230229

231230
self = torch.Tensor._make_subclass(cls, data, requires_grad)
232231
self.blocksize = blocksize

tests/test_parametrize.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
3737
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
3838
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
3939
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
40-
@pytest.mark.parametrize(
41-
"blocksize",
42-
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
43-
)
40+
@pytest.mark.parametrize("blocksize", [64, 128, 256])
4441
def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
4542
"""Test basic parameter replacement with 4-bit quantization on different dtypes."""
4643
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -267,7 +264,7 @@ def test_quant_state_preservation(device, dtype):
267264

268265
module = ParametrizeTestModule(device=device, dtype=dtype)
269266

270-
blocksize = 128 if ROCM_WARP_SIZE_64 else 64
267+
blocksize = 64
271268

272269
# Apply parametrization with specific settings
273270
replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)

0 commit comments

Comments (0)