Skip to content

Commit b3b3bcc

Browse files
authored
Adding support for building for AMD on Windows (#1846)
* Adding support for building for AMD on Windows
* Simplify building AMD on Windows by auto-detecting compiler and GPU arch
* Standardize Windows detection to use platform.system()
* Made some of the suggested changes
* Reverted return to previous behavior
* Changed to use find_path for lib path
1 parent 746cd64 commit b3b3bcc

File tree

3 files changed

+82
-18
lines changed

3 files changed

+82
-18
lines changed

CMakeLists.txt

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,32 @@
1616
# libbitsandbytes_rocm70.so even if the system has ROCm 7.2.
1717
cmake_minimum_required(VERSION 3.22.1)
1818

19+
# On Windows with HIP backend, auto-detect compilers from ROCM_PATH before project()
20+
if(WIN32 AND COMPUTE_BACKEND STREQUAL "hip")
21+
if(DEFINED ENV{ROCM_PATH})
22+
set(ROCM_PATH $ENV{ROCM_PATH})
23+
endif()
24+
if(ROCM_PATH AND NOT DEFINED CMAKE_CXX_COMPILER)
25+
set(CMAKE_CXX_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe")
26+
endif()
27+
if(ROCM_PATH AND NOT DEFINED CMAKE_HIP_COMPILER)
28+
set(CMAKE_HIP_COMPILER "${ROCM_PATH}/lib/llvm/bin/clang++.exe")
29+
endif()
30+
# On Windows, the HIP compiler needs explicit paths to find device libraries.
31+
if(ROCM_PATH)
32+
find_path(ROCM_DEVICE_LIB_PATH
33+
NAMES oclc_abi_version_400.bc ocml.bc
34+
PATHS "${ROCM_PATH}/amdgcn/bitcode"
35+
"${ROCM_PATH}/lib/llvm/amdgcn/bitcode"
36+
NO_DEFAULT_PATH
37+
)
38+
set(CMAKE_HIP_FLAGS "--rocm-path=${ROCM_PATH}")
39+
if(ROCM_DEVICE_LIB_PATH)
40+
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --rocm-device-lib-path=${ROCM_DEVICE_LIB_PATH}")
41+
endif()
42+
endif()
43+
endif()
44+
1945
project(bitsandbytes LANGUAGES CXX)
2046

2147
# If run without specifying a build type, default to using the Release configuration:
@@ -204,17 +230,18 @@ if(BUILD_CUDA)
204230
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
205231
add_compile_definitions(BUILD_CUDA)
206232
elseif(BUILD_HIP)
207-
enable_language(HIP)
208-
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
233+
# Set target architectures before enable_language(HIP), which would otherwise
234+
# auto-detect a single GPU and override the defaults.
209235
if(DEFINED BNB_ROCM_ARCH)
210236
set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
211-
else()
212-
if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
213-
set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100;gfx1101;gfx1150;gfx1151;gfx1200;gfx1201")
214-
elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
215-
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
216-
endif()
237+
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
238+
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
239+
elseif(NOT CMAKE_HIP_ARCHITECTURES)
240+
set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100;gfx1101;gfx1150;gfx1151;gfx1200;gfx1201")
217241
endif()
242+
243+
enable_language(HIP)
244+
message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
218245
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
219246

220247
list(APPEND SRC_FILES ${HIP_FILES})
@@ -275,6 +302,8 @@ endif()
275302
if(WIN32)
276303
# Export all symbols
277304
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
305+
# Prevent Windows SDK min/max macros from conflicting with std::min/std::max
306+
add_compile_definitions(NOMINMAX)
278307
endif()
279308

280309
if(MSVC)
@@ -327,10 +356,11 @@ if(BUILD_CUDA)
327356
)
328357
endif()
329358
if(BUILD_HIP)
330-
if(NOT DEFINED ENV{ROCM_PATH})
331-
set(ROCM_PATH /opt/rocm)
332-
else()
359+
# Take ROCM_PATH from the environment variable; if it is unset, fall back to /opt/rocm (a Linux path, applied on all platforms)
360+
if(DEFINED ENV{ROCM_PATH})
333361
set(ROCM_PATH $ENV{ROCM_PATH})
362+
else()
363+
set(ROCM_PATH /opt/rocm)
334364
endif()
335365
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
336366
macro(find_package_and_print_version PACKAGE_NAME)
@@ -342,14 +372,23 @@ if(BUILD_HIP)
342372
find_package_and_print_version(hipsparse REQUIRED)
343373

344374
## Hacky way of excluding hip::amdhip64: with it linked, many tests (e.g. adam8bit) unexpectedly fail because of inaccuracies.
345-
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
346-
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
347-
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
375+
## On Windows, we need to link amdhip64 explicitly
376+
if(NOT WIN32)
377+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
378+
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
379+
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
380+
endif()
348381

349382
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
350383
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
351384
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
352385

386+
# On Windows, rocblas is not pulled in transitively by roc::hipblas
387+
# and is needed because ops_hip.cuh uses rocblas_handle directly.
388+
if(WIN32)
389+
target_link_libraries(bitsandbytes PUBLIC rocblas)
390+
endif()
391+
353392
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
354393
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
355394
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

bitsandbytes/cuda_specs.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
from functools import lru_cache
33
import logging
4+
import platform
45
import re
56
import subprocess
67
from typing import Optional
@@ -83,8 +84,16 @@ def get_rocm_gpu_arch() -> str:
8384
logger = logging.getLogger(__name__)
8485
try:
8586
if torch.version.hip:
86-
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
87-
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
87+
# On Windows, use hipinfo.exe; on Linux, use rocminfo
88+
if platform.system() == "Windows":
89+
cmd = ["hipinfo.exe"]
90+
arch_pattern = r"gcnArchName:\s+gfx([a-zA-Z\d]+)"
91+
else:
92+
cmd = ["rocminfo"]
93+
arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"
94+
95+
result = subprocess.run(cmd, capture_output=True, text=True)
96+
match = re.search(arch_pattern, result.stdout)
8897
if match:
8998
return "gfx" + match.group(1)
9099
else:
@@ -107,8 +116,17 @@ def get_rocm_warpsize() -> int:
107116
logger = logging.getLogger(__name__)
108117
try:
109118
if torch.version.hip:
110-
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
111-
match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
119+
# On Windows, use hipinfo.exe; on Linux, use rocminfo
120+
if platform.system() == "Windows":
121+
cmd = ["hipinfo.exe"]
122+
# hipinfo.exe output format: "warpSize: 32" or "warpSize: 64"
123+
warp_pattern = r"warpSize:\s+(\d+)"
124+
else:
125+
cmd = ["rocminfo"]
126+
warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)"
127+
128+
result = subprocess.run(cmd, capture_output=True, text=True)
129+
match = re.search(warp_pattern, result.stdout)
112130
if match:
113131
return int(match.group(1))
114132
else:

csrc/ops_hip.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
#include <cstdint>
1212
#include <iostream>
1313
#include <stdio.h>
14+
15+
#ifdef _WIN32
16+
#include <io.h>
17+
#include <process.h>
18+
#include <windows.h>
19+
#else
1420
#include <unistd.h>
21+
#endif
1522

1623
#include <common.h>
1724
#include <functional>

0 commit comments

Comments
 (0)