|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path |
| 3 | +from bitsandbytes.cextension import BNB_BACKEND, get_cuda_bnb_library_path |
4 | 4 | from bitsandbytes.cuda_specs import CUDASpecs |
5 | 5 |
|
6 | 6 |
|
7 | 7 | @pytest.fixture |
8 | 8 | def cuda120_spec() -> CUDASpecs: |
| 9 | + """Simulates torch+cuda12.0 and a representative Ampere-class capability.""" |
9 | 10 | return CUDASpecs( |
10 | 11 | cuda_version_string="120", |
11 | 12 | highest_compute_capability=(8, 6), |
12 | 13 | cuda_version_tuple=(12, 0), |
13 | 14 | ) |
14 | 15 |
|
15 | 16 |
|
16 | | -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") |
| 17 | +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") |
17 | 18 | def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): |
| 19 | + """Without overrides, library path uses the detected CUDA 12.0 version.""" |
| 20 | + monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) |
18 | 21 | monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) |
19 | 22 | assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" |
20 | 23 |
|
21 | 24 |
|
22 | | -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") |
| 25 | +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") |
23 | 26 | def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): |
| 27 | + """BNB_CUDA_VERSION=110 overrides path selection to the CUDA 11.0 binary.""" |
| 28 | + monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) |
24 | 29 | monkeypatch.setenv("BNB_CUDA_VERSION", "110") |
25 | 30 | assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" |
26 | 31 | assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? |
27 | 32 |
|
28 | 33 |
|
29 | | -# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2 |
| 34 | +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") |
| 35 | +def test_get_cuda_bnb_library_path_rejects_rocm_override(monkeypatch, cuda120_spec): |
| 36 | + """BNB_ROCM_VERSION alone should be rejected on CUDA with a helpful error.""" |
| 37 | + monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) |
| 38 | + monkeypatch.setenv("BNB_ROCM_VERSION", "72") |
| 39 | + with pytest.raises(RuntimeError, match=r"BNB_ROCM_VERSION.*not a ROCm build"): |
| 40 | + get_cuda_bnb_library_path(cuda120_spec) |
| 41 | + |
| 42 | + |
| 43 | +@pytest.mark.skipif(BNB_BACKEND != "CUDA", reason="this test requires a CUDA backend") |
| 44 | +def test_get_cuda_bnb_library_path_cuda_override_takes_priority(monkeypatch, cuda120_spec, caplog): |
| 45 | + """When both overrides are set on CUDA, the CUDA override wins and the ROCm one is warned about.""" |
| 46 | + monkeypatch.setenv("BNB_CUDA_VERSION", "110") |
| 47 | + monkeypatch.setenv("BNB_ROCM_VERSION", "72") |
| 48 | + assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" |
| 49 | + assert "BNB_CUDA_VERSION" in caplog.text |
| 50 | + assert "BNB_ROCM_VERSION" in caplog.text |
| 51 | + assert "ignored on this CUDA build" in caplog.text |
| 52 | + |
| 53 | + |
30 | 54 | @pytest.fixture |
31 | 55 | def rocm70_spec() -> CUDASpecs: |
| 56 | + """Simulates torch+rocm7.0 (bundled ROCm) when the system ROCm is newer.""" |
32 | 57 | return CUDASpecs( |
33 | | - cuda_version_string="70", # from torch.version.hip == "7.0.x" |
34 | | - highest_compute_capability=(0, 0), # unused for ROCm library path resolution |
| 58 | + cuda_version_string="70", |
| 59 | + highest_compute_capability=(0, 0), |
35 | 60 | cuda_version_tuple=(7, 0), |
36 | 61 | ) |
37 | 62 |
|
38 | 63 |
|
39 | | -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") |
| 64 | +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") |
40 | 65 | def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec): |
41 | 66 | """Without override, library path uses PyTorch's ROCm 7.0 version.""" |
42 | 67 | monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) |
43 | 68 | monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) |
44 | 69 | assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70" |
45 | 70 |
|
46 | 71 |
|
47 | | -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") |
| 72 | +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") |
48 | 73 | def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog): |
49 | 74 | """BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0.""" |
50 | | - monkeypatch.setenv("BNB_ROCM_VERSION", "72") |
51 | 75 | monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) |
| 76 | + monkeypatch.setenv("BNB_ROCM_VERSION", "72") |
52 | 77 | assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72" |
53 | 78 | assert "BNB_ROCM_VERSION" in caplog.text |
54 | 79 |
|
55 | 80 |
|
56 | | -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") |
57 | | -def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec): |
58 | | - """BNB_CUDA_VERSION should be rejected on ROCm with a helpful error.""" |
59 | | - monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) |
60 | | - monkeypatch.setenv("BNB_CUDA_VERSION", "72") |
61 | | - with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"): |
62 | | - get_cuda_bnb_library_path(rocm70_spec) |
63 | | - |
64 | | - |
65 | | -@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm") |
| 81 | +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") |
66 | 82 | def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog): |
67 | | - """When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True.""" |
| 83 | + """When both overrides are set on ROCm, the ROCm override wins and the CUDA one is warned about.""" |
68 | 84 | monkeypatch.setenv("BNB_ROCM_VERSION", "72") |
69 | | - monkeypatch.setenv("BNB_CUDA_VERSION", "72") |
| 85 | + monkeypatch.setenv("BNB_CUDA_VERSION", "110") |
70 | 86 | assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72" |
71 | 87 | assert "BNB_ROCM_VERSION" in caplog.text |
72 | | - assert "BNB_CUDA_VERSION" not in caplog.text |
| 88 | + assert "BNB_CUDA_VERSION" in caplog.text |
| 89 | + assert "ignored on this ROCm build" in caplog.text |
| 90 | + |
| 91 | + |
| 92 | +@pytest.mark.skipif(BNB_BACKEND != "ROCm", reason="this test requires a ROCm backend") |
| 93 | +def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec): |
| 94 | + """BNB_CUDA_VERSION alone should be rejected on ROCm with a helpful error.""" |
| 95 | + monkeypatch.delenv("BNB_ROCM_VERSION", raising=False) |
| 96 | + monkeypatch.setenv("BNB_CUDA_VERSION", "110") |
| 97 | + with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*not a CUDA build"): |
| 98 | + get_cuda_bnb_library_path(rocm70_spec) |
0 commit comments