Skip to content

Commit 44fad6c

Browse files
committed
Enable previously skipped tests on CDNA (warp size 64)
Remove ROCM_WARP_SIZE_64 guards from all test files now that blocksize-32/64 quantization and GEMV kernels work on warp-64 hardware.
1 parent e3b9c46 commit 44fad6c

File tree

5 files changed

+13
-29
lines changed

5 files changed

+13
-29
lines changed

tests/test_functional.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import bitsandbytes as bnb
1212
from bitsandbytes import functional as F
13-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1413
from tests.helpers import (
1514
BOOLEAN_TUPLES,
1615
TRUE_FALSE,
@@ -96,7 +95,7 @@ class Test8BitBlockwiseQuantizeFunctional:
9695
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
9796
@pytest.mark.parametrize(
9897
"blocksize",
99-
[4096, 2048, 1024, 512, 256, 128, 64] if not ROCM_WARP_SIZE_64 else [4096, 2048, 1024, 512, 256, 128],
98+
[4096, 2048, 1024, 512, 256, 128, 64],
10099
)
101100
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
102101
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
@@ -509,7 +508,6 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
509508
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
510509
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
511510
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
512-
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
513511
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
514512
def min_max(x):
515513
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -844,7 +842,7 @@ class TestQuantize4BitFunctional:
844842
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
845843
@pytest.mark.parametrize(
846844
"blocksize",
847-
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512, 1024, 2048, 4096],
845+
[32, 64, 128, 256, 512, 1024, 2048, 4096],
848846
)
849847
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
850848
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -927,9 +925,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
927925

928926
@pytest.mark.parametrize("device", get_available_devices())
929927
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
930-
@pytest.mark.parametrize(
931-
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
932-
)
928+
@pytest.mark.parametrize("blocksize", [32, 64, 128], ids=id_formatter("blocksize"))
933929
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
934930
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
935931
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -966,9 +962,7 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
966962
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
967963
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
968964
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
969-
@pytest.mark.parametrize(
970-
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128], ids=id_formatter("blocksize")
971-
)
965+
@pytest.mark.parametrize("blocksize", [32, 64, 128], ids=id_formatter("blocksize"))
972966
def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
973967
"""
974968
Test that we can successfully quantize a large tensor. Note that the following limitations apply:
@@ -1028,9 +1022,6 @@ def test_bench_4bit_dequant(self, quant_type):
10281022
# torch.cuda.synchronize()
10291023
# print((time.time()-t0)/iters*1e6)
10301024

1031-
@pytest.mark.skipif(
1032-
ROCM_WARP_SIZE_64, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
1033-
)
10341025
@pytest.mark.parametrize("device", get_available_devices())
10351026
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
10361027
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
@@ -1185,7 +1176,6 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
11851176
@pytest.mark.parametrize("device", get_available_devices())
11861177
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
11871178
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
1188-
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
11891179
def test_gemv_eye_4bit(self, device, storage_type, dtype):
11901180
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
11911181
pytest.skip("This configuration is not supported on HPU.")

tests/test_linear4bit.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch
1212

1313
import bitsandbytes as bnb
14-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1514
from tests.helpers import (
1615
TRUE_FALSE,
1716
describe_dtype,
@@ -195,7 +194,7 @@ def test_linear_serialization(
195194

196195
@pytest.mark.parametrize("device", get_available_devices())
197196
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
198-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
197+
@pytest.mark.parametrize("blocksize", [32, 64, 128])
199198
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
200199
def test_copy_param(device, quant_type, blocksize, compress_statistics):
201200
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -286,7 +285,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
286285

287286
@pytest.mark.parametrize("device", get_available_devices())
288287
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
289-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
288+
@pytest.mark.parametrize("blocksize", [32, 64, 128])
290289
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
291290
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
292291
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -315,7 +314,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
315314

316315
@pytest.mark.parametrize("device", get_available_devices())
317316
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
318-
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [64, 128])
317+
@pytest.mark.parametrize("blocksize", [32, 64, 128])
319318
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
320319
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
321320
if device == "hpu" and not is_supported_on_hpu(quant_type):

tests/test_linear8bitlt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch
1111

1212
import bitsandbytes as bnb
13-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
1413
from bitsandbytes.nn.modules import Linear8bitLt
1514
from tests.helpers import (
1615
TRUE_FALSE,
@@ -238,7 +237,6 @@ def test_linear8bit_serialization(linear8bit):
238237
@pytest.mark.skipif(
239238
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
240239
)
241-
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
242240
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
243241
if device == "cuda" and platform.system() == "Windows":
244242
pytest.skip("Triton is not officially supported on Windows")

tests/test_ops.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55

66
import bitsandbytes
7-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
87
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu
98

109
# torch.library.opcheck is only available in torch 2.4 and later.
@@ -102,7 +101,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias):
102101
class TestInt8BlockwiseQuantOps:
103102
@pytest.mark.parametrize("device", get_available_devices())
104103
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
105-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
104+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
106105
def test_quantize_blockwise(self, device, dtype, blocksize):
107106
if device == "cpu":
108107
if dtype != torch.float32:
@@ -126,7 +125,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
126125

127126
@pytest.mark.parametrize("device", get_available_devices())
128127
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
129-
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
128+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
130129
def test_dequantize_blockwise(self, device, dtype, blocksize):
131130
if device == "cpu" and dtype != torch.float32:
132131
pytest.skip("CPU implementation is only available for float32")
@@ -152,7 +151,7 @@ class Test4bitBlockwiseQuantOps:
152151
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
153152
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
154153
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
155-
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
154+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512])
156155
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
157156
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
158157
pytest.skip("This configuration is not supported on HPU.")
@@ -202,7 +201,7 @@ def test_quantize_4bit_not_divisible_by_blocksize(self, device, dtype, quant_typ
202201
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
203202
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
204203
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
205-
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [64, 128, 256, 512])
204+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512])
206205
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
207206
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
208207
pytest.skip("This configuration is not supported on HPU.")
@@ -236,8 +235,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
236235
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
237236
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
238237
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
239-
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
240-
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
238+
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512])
241239
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
242240
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
243241
pytest.skip("This configuration is not supported on HPU.")

tests/test_parametrize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch.nn as nn
44

55
from bitsandbytes import functional as F
6-
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
76
from bitsandbytes.nn.parametrize import (
87
Bnb4bitParametrization,
98
replace_parameter_4bit,
@@ -336,7 +335,7 @@ def test_multiple_parameters(device, dtype):
336335
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
337336
@pytest.mark.parametrize(
338337
"blocksize",
339-
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
338+
[64, 128, 256],
340339
)
341340
def test_different_blocksizes(device, dtype, blocksize):
342341
"""Test parametrization with different block sizes to verify flexibility."""

0 commit comments

Comments (0)