Skip to content

Commit 6a18715

Browse files
authored
Update CUDA/ROCm setup tests (#1899)
* Update CUDA/ROCm setup tests
* Fix pre-commit
* Add the assert that was removed by mistake
* Fix fragile override tests and restructure backend override logic
  - Clear stray env vars in override tests to prevent false failures
  - Branch on active backend first to handle both-overrides-set case
  - Warn (instead of silently ignoring) the wrong override when both are set
  - Reject overrides on unsupported backends (e.g. XPU)
  - Add symmetric both-overrides-set tests for CUDA and ROCm
1 parent 45bd314 commit 6a18715

2 files changed

Lines changed: 90 additions & 41 deletions

File tree

bitsandbytes/cextension.py

Lines changed: 42 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -30,29 +30,52 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
3030
prefix = "rocm" if torch.version.hip else "cuda"
3131
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
3232

33-
override_value = os.environ.get("BNB_CUDA_VERSION")
33+
cuda_override_value = os.environ.get("BNB_CUDA_VERSION")
3434
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")
3535

36-
if rocm_override_value and torch.version.hip:
37-
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
38-
logger.warning(
39-
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
40-
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
41-
"If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
42-
)
43-
elif override_value:
44-
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
45-
if torch.version.hip:
36+
if torch.version.hip:
37+
if cuda_override_value:
38+
if not rocm_override_value:
39+
raise RuntimeError(
40+
f"BNB_CUDA_VERSION={cuda_override_value} detected but this is not a CUDA build!\n"
41+
"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
42+
"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
43+
)
44+
logger.warning(
45+
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} is set but ignored on this ROCm build. "
46+
"Clear the variable: unset BNB_CUDA_VERSION",
47+
)
48+
if rocm_override_value:
49+
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
50+
logger.warning(
51+
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
52+
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
53+
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n",
54+
)
55+
elif torch.version.cuda:
56+
if rocm_override_value:
57+
if not cuda_override_value:
58+
raise RuntimeError(
59+
f"BNB_ROCM_VERSION={rocm_override_value} detected but this is not a ROCm build!\n"
60+
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
61+
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
62+
)
63+
logger.warning(
64+
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} is set but ignored on this CUDA build. "
65+
"Clear the variable: unset BNB_ROCM_VERSION",
66+
)
67+
if cuda_override_value:
68+
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
69+
logger.warning(
70+
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
71+
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
72+
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n",
73+
)
74+
else:
75+
if rocm_override_value or cuda_override_value:
4676
raise RuntimeError(
47-
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
48-
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
49-
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
77+
"BNB_ROCM_VERSION / BNB_CUDA_VERSION overrides are not supported on this backend.",
5078
)
51-
logger.warning(
52-
f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
53-
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
54-
"If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n"
55-
)
5679

5780
return PACKAGE_DIR / library_name
5881

tests/test_cuda_setup_evaluator.py

Lines changed: 48 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,72 +1,98 @@
11
import pytest
22

3-
from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
3+
from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path
44
from bitsandbytes.cuda_specs import CUDASpecs
55

66

77
@pytest.fixture
88
def cuda120_spec() -> CUDASpecs:
9+
"""Simulates torch+cuda12.0 and a representative Ampere-class capability."""
910
return CUDASpecs(
1011
cuda_version_string="120",
1112
highest_compute_capability=(8, 6),
1213
cuda_version_tuple=(12, 0),
1314
)
1415

1516

16-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
17+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
1718
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
19+
"""Without overrides, library path uses the detected CUDA 12.0 version."""
20+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
1821
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
1922
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
2023

2124

22-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
25+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
2326
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
27+
"""BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
28+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
2429
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
2530
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
2631
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
2732

2833

29-
# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
34+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
35+
def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec):
36+
"""BNB_ROCM_VERSION alone should be rejected on CUDA with a helpful error."""
37+
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
38+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
39+
with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*not a ROCm build"):
40+
get_cuda_bnb_library_path(cuda120_spec)
41+
42+
43+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
44+
def test_get_cuda_bnb_library_path_cuda_override_takes_priority(monkeypatch, cuda120_spec, caplog):
45+
"""When both overrides are set on CUDA, the CUDA override wins and the ROCm one is warned about."""
46+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
47+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
48+
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
49+
assert "BNB_CUDA_VERSION" in caplog.text
50+
assert "BNB_ROCM_VERSION" in caplog.text
51+
assert "ignored on this CUDA build" in caplog.text
52+
53+
3054
@pytest.fixture
3155
def rocm70_spec() -> CUDASpecs:
56+
"""Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
3257
return CUDASpecs(
33-
cuda_version_string="70", # from torch.version.hip == "7.0.x"
34-
highest_compute_capability=(0, 0), # unused for ROCm library path resolution
58+
cuda_version_string="70",
59+
highest_compute_capability=(0, 0),
3560
cuda_version_tuple=(7, 0),
3661
)
3762

3863

39-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
64+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
4065
def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
4166
"""Without override, library path uses PyTorch's ROCm 7.0 version."""
4267
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
4368
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
4469
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70"
4570

4671

47-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
72+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
4873
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
4974
"""BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
50-
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
5175
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
76+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
5277
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
5378
assert "BNB_ROCM_VERSION" in caplog.text
5479

5580

56-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
57-
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
58-
"""BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
59-
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
60-
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
61-
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
62-
get_cuda_bnb_library_path(rocm70_spec)
63-
64-
65-
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
81+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
6682
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
67-
"""When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
83+
"""When both overrides are set on ROCm, the ROCm override wins and the CUDA one is warned about."""
6884
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
69-
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
85+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
7086
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
7187
assert "BNB_ROCM_VERSION" in caplog.text
72-
assert "BNB_CUDA_VERSION" not in caplog.text
88+
assert "BNB_CUDA_VERSION" in caplog.text
89+
assert "ignored on this ROCm build" in caplog.text
90+
91+
92+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
93+
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
94+
"""BNB_CUDA_VERSION alone should be rejected on ROCm with a helpful error."""
95+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
96+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
97+
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"):
98+
get_cuda_bnb_library_path(rocm70_spec)

0 commit comments

Comments
 (0)