Skip to content

Commit b546e82

Browse files
committed
Fix fragile override tests and restructure backend override logic
- Clear stray env vars in override tests to prevent false failures - Branch on active backend first to handle both-overrides-set case - Warn (instead of silently ignoring) the wrong override when both are set - Reject overrides on unsupported backends (e.g. XPU) - Add symmetric both-overrides-set tests for CUDA and ROCm
1 parent 3a8debc commit b546e82

2 files changed

Lines changed: 69 additions & 28 deletions

File tree

bitsandbytes/cextension.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,32 +33,49 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
3333
cuda_override_value = os.environ.get("BNB_CUDA_VERSION")
3434
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")
3535

36-
if rocm_override_value:
37-
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
38-
if torch.version.cuda:
39-
raise RuntimeError(
40-
f"BNB_ROCM_VERSION={rocm_override_value} detected for CUDA!\n"
41-
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
42-
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
36+
if torch.version.hip:
37+
if cuda_override_value:
38+
if not rocm_override_value:
39+
raise RuntimeError(
40+
f"BNB_CUDA_VERSION={cuda_override_value} detected but this is not a CUDA build!\n"
41+
"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
42+
"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
43+
)
44+
logger.warning(
45+
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} is set but ignored on this ROCm build. "
46+
"Clear the variable: unset BNB_CUDA_VERSION",
4347
)
44-
logger.warning(
45-
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
46-
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
47-
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n"
48-
)
49-
elif cuda_override_value:
50-
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
51-
if torch.version.hip:
48+
if rocm_override_value:
49+
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
50+
logger.warning(
51+
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
52+
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
53+
"If this was unintended clear the variable and retry: unset BNB_ROCM_VERSION\n",
54+
)
55+
elif torch.version.cuda:
56+
if rocm_override_value:
57+
if not cuda_override_value:
58+
raise RuntimeError(
59+
f"BNB_ROCM_VERSION={rocm_override_value} detected but this is not a ROCm build!\n"
60+
"Use BNB_CUDA_VERSION instead: export BNB_CUDA_VERSION=<version>\n"
61+
"Clear the variable and retry: unset BNB_ROCM_VERSION\n"
62+
)
63+
logger.warning(
64+
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} is set but ignored on this CUDA build. "
65+
"Clear the variable: unset BNB_ROCM_VERSION",
66+
)
67+
if cuda_override_value:
68+
library_name = re.sub(r"cuda\d+", f"cuda{cuda_override_value}", library_name, count=1)
69+
logger.warning(
70+
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
71+
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
72+
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n",
73+
)
74+
else:
75+
if rocm_override_value or cuda_override_value:
5276
raise RuntimeError(
53-
f"BNB_CUDA_VERSION={cuda_override_value} detected for ROCm!\n"
54-
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
55-
f"Clear the variable and retry: unset BNB_CUDA_VERSION\n"
77+
"BNB_ROCM_VERSION / BNB_CUDA_VERSION overrides are not supported on this backend.",
5678
)
57-
logger.warning(
58-
f"WARNING: BNB_CUDA_VERSION={cuda_override_value} environment variable detected; loading {library_name}.\n"
59-
"This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
60-
"If this was unintended clear the variable and retry: unset BNB_CUDA_VERSION\n"
61-
)
6279

6380
return PACKAGE_DIR / library_name
6481

tests/test_cuda_setup_evaluator.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,32 @@ def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
2525
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
2626
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
2727
"""BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
28+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
2829
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
2930
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
3031
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
3132

3233

3334
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
3435
def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec):
35-
"""BNB_ROCM_VERSION should be rejected on CUDA with a helpful error."""
36+
"""BNB_ROCM_VERSION alone should be rejected on CUDA with a helpful error."""
3637
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
3738
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
38-
with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*detected for CUDA"):
39+
with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*not a ROCm build"):
3940
get_cuda_bnb_library_path(cuda120_spec)
4041

4142

43+
@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend")
44+
def test_get_cuda_bnb_library_path_cuda_override_takes_priority(monkeypatch, cuda120_spec, caplog):
45+
"""When both overrides are set on CUDA, the CUDA override wins and the ROCm one is warned about."""
46+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
47+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
48+
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
49+
assert "BNB_CUDA_VERSION" in caplog.text
50+
assert "BNB_ROCM_VERSION" in caplog.text
51+
assert "ignored on this CUDA build" in caplog.text
52+
53+
4254
@pytest.fixture
4355
def rocm70_spec() -> CUDASpecs:
4456
"""Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
@@ -60,15 +72,27 @@ def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
6072
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
6173
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
6274
"""BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
75+
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
6376
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
6477
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
6578
assert "BNB_ROCM_VERSION" in caplog.text
6679

6780

81+
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
82+
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
83+
"""When both overrides are set on ROCm, the ROCm override wins and the CUDA one is warned about."""
84+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
85+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
86+
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
87+
assert "BNB_ROCM_VERSION" in caplog.text
88+
assert "BNB_CUDA_VERSION" in caplog.text
89+
assert "ignored on this ROCm build" in caplog.text
90+
91+
6892
@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend")
6993
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
70-
"""BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
94+
"""BNB_CUDA_VERSION alone should be rejected on ROCm with a helpful error."""
7195
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
72-
monkeypatch.setenv("BNB_CUDA_VERSION", "120")
73-
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
96+
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
97+
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"):
7498
get_cuda_bnb_library_path(rocm70_spec)

0 commit comments

Comments
 (0)