Skip to content

Commit 746cd64

Browse files
authored
Add BNB_ROCM_VERSION and ROCM_VERSION for ROCm/PyTorch version mismatch (#1878)
* Add BNB_ROCM_VERSION and ROCM_VERSION for ROCm/PyTorch version mismatch

  When PyTorch is built with a different ROCm version than the system (e.g. torch+rocm7.0 on ROCm 7.2), bitsandbytes fails to find the native library because the build uses hipconfig (system) while runtime uses torch.version.hip (PyTorch).

  - Add BNB_ROCM_VERSION env var (runtime): override which ROCm library is loaded, analogous to BNB_CUDA_VERSION. Takes priority when both BNB_ROCM_VERSION and BNB_CUDA_VERSION are set on ROCm.
  - Add ROCM_VERSION CMake cache variable (build): override the version shortcode in the output library name (e.g. -DROCM_VERSION=70 produces libbitsandbytes_rocm70.so on a 7.2 system).
  - Update diagnostics and error messages to mention BNB_ROCM_VERSION; align _print_hip_runtime_diagnostics with _print_cuda_runtime_diagnostics.
  - Reject BNB_CUDA_VERSION on ROCm with a clear error pointing to BNB_ROCM_VERSION.
  - Add ROCm tests: default path, override, rejection of BNB_CUDA_VERSION, and both vars set (ROCM wins).

  Fixes ROCm#82.

* Lint
1 parent 67755b9 commit 746cd64

File tree

5 files changed

+98
-15
lines changed

5 files changed

+98
-15
lines changed

CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
# Separate by semicolons, i.e. `-DCOMPUTE_CAPABILITY=89;90;100;120`
1111
# Check your compute capability here: https://developer.nvidia.com/cuda-gpus
1212
# - PTXAS_VERBOSE: Pass the `-v` option to the PTX Assembler
13+
# - ROCM_VERSION: Override the ROCm version shortcode used in the output library name.
14+
# Useful when PyTorch was built against a different ROCm version than the
15+
# system install. For example, `-DROCM_VERSION=70` produces
16+
# libbitsandbytes_rocm70.so even if the system has ROCm 7.2.
1317
cmake_minimum_required(VERSION 3.22.1)
1418

1519
project(bitsandbytes LANGUAGES CXX)
@@ -222,7 +226,15 @@ elseif(BUILD_HIP)
222226
string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
223227
string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")
224228

225-
string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
229+
# Expose a cache variable that the user can set to override the ROCm version in the library name
230+
set(ROCM_VERSION "${HIP_VERSION_SHORT}" CACHE STRING "Expected ROCm Version Shortcode")
231+
232+
message(STATUS "ROCm Version: ${HIP_VERSION_SHORT} (from hipconfig)")
233+
if(NOT ROCM_VERSION STREQUAL "${HIP_VERSION_SHORT}")
234+
message(WARNING "Overriding ROCm version in library name: ${HIP_VERSION_SHORT} -> ${ROCM_VERSION}")
235+
endif()
236+
237+
string(APPEND BNB_OUTPUT_NAME "${ROCM_VERSION}")
226238
add_compile_definitions(__HIP_PLATFORM_AMD__)
227239
add_compile_definitions(__HIP_PLATFORM_HCC__)
228240
add_compile_definitions(BUILD_HIP)

agents/architecture_guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ GPU-specific functions are actually invoked.
329329
### Environment variables
330330

331331
- `BNB_CUDA_VERSION` — Override the auto-detected CUDA version for library selection
332+
- `BNB_ROCM_VERSION` — Override the auto-detected ROCm version for library selection (the ROCm equivalent of `BNB_CUDA_VERSION`)
332333
- Standard CUDA env vars (`CUDA_HOME`, `LD_LIBRARY_PATH`) affect library discovery
333334

334335
---

bitsandbytes/cextension.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,21 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
3232
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
3333

3434
override_value = os.environ.get("BNB_CUDA_VERSION")
35-
if override_value:
35+
rocm_override_value = os.environ.get("BNB_ROCM_VERSION")
36+
37+
if rocm_override_value and torch.version.hip:
38+
library_name = re.sub(r"rocm\d+", f"rocm{rocm_override_value}", library_name, count=1)
39+
logger.warning(
40+
f"WARNING: BNB_ROCM_VERSION={rocm_override_value} environment variable detected; loading {library_name}.\n"
41+
"This can be used to load a bitsandbytes version built with a ROCm version that is different from the PyTorch ROCm version.\n"
42+
"If this was unintended set the BNB_ROCM_VERSION variable to an empty string: export BNB_ROCM_VERSION=\n"
43+
)
44+
elif override_value:
3645
library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
3746
if torch.version.hip:
3847
raise RuntimeError(
3948
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
49+
f"Use BNB_ROCM_VERSION instead: export BNB_ROCM_VERSION=<version>\n"
4050
f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
4151
)
4252
logger.warning(
@@ -122,7 +132,7 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
122132
1. Missing shared library dependencies (e.g., libcudart.so not in LD_LIBRARY_PATH or through PyTorch CUDA installation)
123133
2. CUDA version mismatch between PyTorch and available pre-compiled binaries
124134
3. Completely missing pre-compiled binaries when CUDA is detected
125-
4. Custom BNB_CUDA_VERSION override but mismatch
135+
4. Custom BNB_CUDA_VERSION or BNB_ROCM_VERSION override but mismatch
126136
5. CPU-only installation attempts when GPU functionality is requested
127137
128138
"""
@@ -131,7 +141,9 @@ def __init__(self, error_msg: str):
131141
self.error_msg = error_msg
132142
self.user_cuda_version = get_cuda_version_tuple()
133143
self.available_versions = get_available_cuda_binary_versions()
134-
self.override_value = os.environ.get("BNB_CUDA_VERSION")
144+
self.override_value = (
145+
os.environ.get("BNB_ROCM_VERSION") if HIP_ENVIRONMENT else os.environ.get("BNB_CUDA_VERSION")
146+
)
135147
self.requested_version = (
136148
parse_cuda_version(self.override_value)
137149
if self.override_value
@@ -217,8 +229,10 @@ def _format_lib_error_message(
217229
)
218230
if not HIP_ENVIRONMENT
219231
else (
220-
"You can COMPILE FROM SOURCE as mentioned here:\n"
232+
"You have two options:\n"
233+
"1. COMPILE FROM SOURCE as mentioned here:\n"
221234
" https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
235+
"2. Use BNB_ROCM_VERSION to specify a DIFFERENT ROCm version from the detected one, matching the version the library was built with.\n\n"
222236
)
223237
)
224238

bitsandbytes/diagnostics/cuda.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -135,15 +135,21 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
135135
def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
136136
print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
137137

138+
rocm_override = os.environ.get("BNB_ROCM_VERSION")
139+
if rocm_override:
140+
print(f"BNB_ROCM_VERSION override: {rocm_override}")
141+
138142
binary_path = get_cuda_bnb_library_path(cuda_specs)
139143
if not binary_path.exists():
140144
print_dedented(
141145
f"""
142-
Library not found: {binary_path}.
143-
Maybe you need to compile it from source? If you compiled from source, check that ROCm version
144-
in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
145-
and rebuild bitsandbytes.
146-
""",
146+
Library not found: {binary_path}.
147+
Maybe you need to compile it from source? If you compiled from source, check that ROCm version
148+
in PyTorch Settings matches your ROCm install. If not, you can either:
149+
1. Reinstall PyTorch for your ROCm version and rebuild bitsandbytes.
150+
2. Set BNB_ROCM_VERSION to match the version the library was built with.
151+
For example: export BNB_ROCM_VERSION=72
152+
""",
147153
)
148154

149155
hip_major, hip_minor = cuda_specs.cuda_version_tuple
@@ -192,22 +198,26 @@ def _print_cuda_runtime_diagnostics() -> None:
192198
def _print_hip_runtime_diagnostics() -> None:
193199
cudart_paths = list(find_cudart_libraries())
194200
if not cudart_paths:
195-
print("WARNING! ROCm runtime files not found in any environmental path.")
201+
print("ROCm SETUP: WARNING! ROCm runtime files not found in any environmental path.")
196202
elif len(cudart_paths) > 1:
197203
print_dedented(
198204
f"""
199205
Found duplicate ROCm runtime files (see below).
200206
201207
We select the PyTorch default ROCm runtime, which is {torch.version.hip},
202208
but this might mismatch with the ROCm version that is needed for bitsandbytes.
209+
To override this behavior set the `BNB_ROCM_VERSION=<version string, e.g. 72>` environmental variable.
210+
211+
For example, if you want to use the ROCm version 7.2,
212+
BNB_ROCM_VERSION=72 python ...
203213
204-
To resolve it, install PyTorch built for the ROCm version you want to use
214+
OR set the environmental variable in your .bashrc:
215+
export BNB_ROCM_VERSION=72
205216
206-
and set LD_LIBRARY_PATH to your ROCm install path, e.g.
207-
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib,
217+
In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g.
218+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-7.2.0/lib,
208219
""",
209220
)
210-
211221
for pth in cudart_paths:
212222
print(f"* Found ROCm runtime at: {pth}")
213223

tests/test_cuda_setup_evaluator.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,49 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
2424
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
2525
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
2626
assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning?
27+
28+
29+
# Simulates torch+rocm7.0 (PyTorch bundled ROCm) on a system with ROCm 7.2
30+
@pytest.fixture
31+
def rocm70_spec() -> CUDASpecs:
32+
return CUDASpecs(
33+
cuda_version_string="70", # from torch.version.hip == "7.0.x"
34+
highest_compute_capability=(0, 0), # unused for ROCm library path resolution
35+
cuda_version_tuple=(7, 0),
36+
)
37+
38+
39+
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
40+
def test_get_rocm_bnb_library_path(monkeypatch, rocm70_spec):
41+
"""Without override, library path uses PyTorch's ROCm 7.0 version."""
42+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
43+
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
44+
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm70"
45+
46+
47+
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
48+
def test_get_rocm_bnb_library_path_override(monkeypatch, rocm70_spec, caplog):
49+
"""BNB_ROCM_VERSION=72 overrides to load the ROCm 7.2 library instead of 7.0."""
50+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
51+
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
52+
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
53+
assert "BNB_ROCM_VERSION" in caplog.text
54+
55+
56+
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
57+
def test_get_rocm_bnb_library_path_rejects_cuda_override(monkeypatch, rocm70_spec):
58+
"""BNB_CUDA_VERSION should be rejected on ROCm with a helpful error."""
59+
monkeypatch.delenv("BNB_ROCM_VERSION", raising=False)
60+
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
61+
with pytest.raises(RuntimeError, match=r"BNB_CUDA_VERSION.*detected for ROCm"):
62+
get_cuda_bnb_library_path(rocm70_spec)
63+
64+
65+
@pytest.mark.skipif(not HIP_ENVIRONMENT, reason="this test is only supported on ROCm")
66+
def test_get_rocm_bnb_library_path_rocm_override_takes_priority(monkeypatch, rocm70_spec, caplog):
67+
"""When both are set, BNB_ROCM_VERSION wins if HIP_ENVIRONMENT is True."""
68+
monkeypatch.setenv("BNB_ROCM_VERSION", "72")
69+
monkeypatch.setenv("BNB_CUDA_VERSION", "72")
70+
assert get_cuda_bnb_library_path(rocm70_spec).stem == "libbitsandbytes_rocm72"
71+
assert "BNB_ROCM_VERSION" in caplog.text
72+
assert "BNB_CUDA_VERSION" not in caplog.text

0 commit comments

Comments
 (0)