@@ -25,20 +25,32 @@ def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
2525@pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend" )
2626def test_get_cuda_bnb_library_path_override (monkeypatch , cuda120_spec , caplog ):
2727 """BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary."""
28+ monkeypatch .delenv ("BNB_ROCM_VERSION" , raising = False )
2829 monkeypatch .setenv ("BNB_CUDA_VERSION" , "110" )
2930 assert get_cuda_bnb_library_path (cuda120_spec ).stem == "libbitsandbytes_cuda110"
3031 assert "BNB_CUDA_VERSION" in caplog .text # did we get the warning?
3132
3233
3334@pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend" )
3435def test_get_cuda_bnb_library_path_rejects_rocm_override (monkeypatch , cuda120_spec ):
35- """BNB_ROCM_VERSION should be rejected on CUDA with a helpful error."""
36+ """BNB_ROCM_VERSION alone should be rejected on CUDA with a helpful error."""
3637 monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
3738 monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
38- with pytest .raises (RuntimeError , match = r"BNB_ROCM_VERSION.*detected for CUDA " ):
39+ with pytest .raises (RuntimeError , match = r"BNB_ROCM_VERSION.*not a ROCm build " ):
3940 get_cuda_bnb_library_path (cuda120_spec )
4041
4142
43+ @pytest .mark .skipif (BNB_BACKEND != "CUDA" , reason = "this test requires a CUDA backend" )
44+ def test_get_cuda_bnb_library_path_cuda_override_takes_priority (monkeypatch , cuda120_spec , caplog ):
45+ """When both overrides are set on CUDA, the CUDA override wins and the ROCm one is warned about."""
46+ monkeypatch .setenv ("BNB_CUDA_VERSION" , "110" )
47+ monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
48+ assert get_cuda_bnb_library_path (cuda120_spec ).stem == "libbitsandbytes_cuda110"
49+ assert "BNB_CUDA_VERSION" in caplog .text
50+ assert "BNB_ROCM_VERSION" in caplog .text
51+ assert "ignored on this CUDA build" in caplog .text
52+
53+
4254@pytest .fixture
4355def rocm70_spec () -> CUDASpecs :
4456 """Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer."""
@@ -60,15 +72,27 @@ def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
6072@pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend" )
6173def test_get_rocm_bnb_library_path_override (monkeypatch , rocm70_spec , caplog ):
6274 """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
75+ monkeypatch .delenv ("BNB_CUDA_VERSION" , raising = False )
6376 monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
6477 assert get_cuda_bnb_library_path (rocm70_spec ).stem == "libbitsandbytes_rocm72"
6578 assert "BNB_ROCM_VERSION" in caplog .text
6679
6780
81+ @pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend" )
82+ def test_get_rocm_bnb_library_path_rocm_override_takes_priority (monkeypatch , rocm70_spec , caplog ):
83+ """When both overrides are set on ROCm, the ROCm override wins and the CUDA one is warned about."""
84+ monkeypatch .setenv ("BNB_ROCM_VERSION" , "72" )
85+ monkeypatch .setenv ("BNB_CUDA_VERSION" , "110" )
86+ assert get_cuda_bnb_library_path (rocm70_spec ).stem == "libbitsandbytes_rocm72"
87+ assert "BNB_ROCM_VERSION" in caplog .text
88+ assert "BNB_CUDA_VERSION" in caplog .text
89+ assert "ignored on this ROCm build" in caplog .text
90+
91+
6892@pytest .mark .skipif (BNB_BACKEND != "ROCm" , reason = "this test requires a ROCm backend" )
6993def test_get_rocm_bnb_library_path_rejects_cuda_override (monkeypatch , rocm70_spec ):
70- """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
94+ """BNB_CUDA_VERSION alone should be rejected on ROCm with a helpful error."""
7195 monkeypatch .delenv ("BNB_ROCM_VERSION" , raising = False )
72- monkeypatch .setenv ("BNB_CUDA_VERSION" , "120 " )
73- with pytest .raises (RuntimeError , match = r"BNB_CUDA_VERSION.*detected for ROCm " ):
96+ monkeypatch .setenv ("BNB_CUDA_VERSION" , "110 " )
97+ with pytest .raises (RuntimeError , match = r"BNB_CUDA_VERSION.*not a CUDA build " ):
7498 get_cuda_bnb_library_path (rocm70_spec )
0 commit comments