@@ -25,14 +25,15 @@ endif()
2525# Define included source files
2626set (CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
2727set (CUDA_FILES csrc/ops.cu csrc/kernels.cu)
28+ set (HIP_FILES csrc/ops.hip csrc/kernels.hip)
2829set (MPS_FILES csrc/mps_ops.mm)
2930set (METAL_FILES csrc/mps_kernels.metal)
3031set (XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
3132# C++ sources are always included
3233list (APPEND SRC_FILES ${CPP_FILES} )
3334
34- set (COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps, xpu)" )
35- set_property (CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps xpu)
35+ set (COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)" )
36+ set_property (CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
3637option (PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF )
3738
3839if (APPLE )
@@ -48,12 +49,21 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
4849 message (FATAL_ERROR "CUDA is not supported on macOS" )
4950 endif ()
5051 set (BUILD_CUDA ON )
52+ set (BUILD_HIP OFF )
53+ set (BUILD_MPS OFF )
54+ elseif (${COMPUTE_BACKEND} STREQUAL "hip" )
55+ if (APPLE )
56+ message (FATAL_ERROR "HIP is not supported on macOS" )
57+ endif ()
58+ set (BUILD_CUDA OFF )
59+ set (BUILD_HIP ON )
5160 set (BUILD_MPS OFF )
5261elseif (${COMPUTE_BACKEND} STREQUAL "mps" )
5362 if (NOT APPLE )
5463 message (FATAL_ERROR "MPS is only supported on macOS" )
5564 endif ()
5665 set (BUILD_CUDA OFF )
66+ set (BUILD_HIP OFF )
5767 set (BUILD_MPS ON )
5868elseif (${COMPUTE_BACKEND} STREQUAL "xpu" )
5969 if (APPLE )
@@ -64,6 +74,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
6474 set (BUILD_XPU ON )
6575else ()
6676 set (BUILD_CUDA OFF )
77+ set (BUILD_HIP OFF )
6778 set (BUILD_MPS OFF )
6879 set (BUILD_XPU OFF )
6980endif ()
@@ -169,6 +180,33 @@ if(BUILD_CUDA)
169180
170181 string (APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT} " )
171182 add_compile_definitions (BUILD_CUDA )
elseif(BUILD_HIP)
    # ROCm/HIP backend: enable the HIP language, choose the GPU architectures,
    # add the HIP sources and derive the "_rocm<version>" library-name suffix.
    enable_language(HIP)
    message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")

    # Architecture precedence: BNB_ROCM_ARCH, then an already-set
    # CMAKE_HIP_ARCHITECTURES, then AMDGPU_TARGETS, then the built-in default.
    # NOTE(review): enable_language(HIP) above may already have initialised
    # CMAKE_HIP_ARCHITECTURES, in which case the two fallbacks below never
    # apply - confirm whether this selection should run before it.
    if(DEFINED BNB_ROCM_ARCH)
        set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
    elseif(NOT CMAKE_HIP_ARCHITECTURES)
        if(AMDGPU_TARGETS)
            set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
        else()
            set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
        endif()
    endif()
    message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")

    list(APPEND SRC_FILES ${HIP_FILES})

    # Ask hipconfig for the HIP version and keep only "major.minor".
    # HIP_VERSION is always set by the regex match (empty when nothing matched).
    execute_process(
        COMMAND hipconfig --version
        OUTPUT_VARIABLE HIP_CONFIG_VERSION
        RESULT_VARIABLE HIP_CONFIG_RESULT
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
    if("${HIP_VERSION}" STREQUAL "")
        # Previously this case was silent and produced an unversioned library name.
        message(WARNING
            "Could not parse a HIP version from `hipconfig --version` "
            "(result: '${HIP_CONFIG_RESULT}', output: '${HIP_CONFIG_VERSION}'); "
            "the output name will carry no ROCm version suffix.")
    endif()
    string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")

    # Library name suffix, e.g. "_rocm61" for HIP 6.1.
    string(APPEND BNB_OUTPUT_NAME "_rocm${HIP_VERSION_SHORT}")

    # Directory-scoped definitions for the HIP build.
    add_compile_definitions(
        __HIP_PLATFORM_AMD__
        __HIP_PLATFORM_HCC__
        BUILD_HIP
    )
172210elseif (BUILD_MPS)
173211 if (NOT APPLE )
174212 message (FATAL_ERROR "MPS is only supported on macOS" )
@@ -223,6 +261,41 @@ if(BUILD_CUDA)
223261 CUDA_SEPARABLE_COMPILATION ON
224262 )
225263endif ()
if(BUILD_HIP)
    # Link-time configuration of the bitsandbytes target for the ROCm/HIP backend.

    # ROCm install prefix: ROCM_PATH from the environment when set and
    # non-empty, otherwise /opt/rocm. An empty value would otherwise turn the
    # paths below into root-level "/include" and "/lib".
    if(DEFINED ENV{ROCM_PATH} AND NOT "$ENV{ROCM_PATH}" STREQUAL "")
        set(ROCM_PATH "$ENV{ROCM_PATH}")
    else()
        set(ROCM_PATH /opt/rocm)
    endif()
    list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}")

    # Runs find_package() for PACKAGE_NAME (extra arguments are forwarded) and
    # prints the <PACKAGE_NAME>_VERSION variable afterwards. Kept as a macro so
    # the variables set by find_package() land in the caller's scope.
    macro(find_package_and_print_version PACKAGE_NAME)
        find_package("${PACKAGE_NAME}" ${ARGN})
        message(STATUS "${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
    endmacro()
    find_package_and_print_version(hipblas REQUIRED)
    find_package_and_print_version(hiprand REQUIRED)
    find_package_and_print_version(hipsparse REQUIRED)

    ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
    # Only touch the imported targets that exist, so a ROCm install that does
    # not define one of them does not abort configuration.
    foreach(hip_host_target IN ITEMS hip::host hip-lang::host)
        if(TARGET ${hip_host_target})
            set_target_properties(${hip_host_target} PROPERTIES INTERFACE_LINK_LIBRARIES "")
        endif()
    endforeach()
    set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")

    # NOTE(review): the bare /include and /lib entries are kept from the
    # original; confirm whether they are still needed.
    target_include_directories(bitsandbytes PRIVATE
        "${CMAKE_CURRENT_SOURCE_DIR}"
        "${CMAKE_CURRENT_SOURCE_DIR}/include"
        "${ROCM_PATH}/include"
        /include
    )
    target_link_directories(bitsandbytes PRIVATE
        "${ROCM_PATH}/lib"
        /lib
    )
    target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)

    target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
    # Compile the HIP sources with the HIP language, but link the target as C++.
    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
    set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

    # hipBLASLt is linked only for HIP >= 6.1; older versions get NO_HIPBLASLT.
    if(HIP_VERSION VERSION_LESS "6.1")
        target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
    else()
        # REQUIRED: roc::hipblaslt is linked unconditionally below, so fail
        # here with a clear message instead of a missing-target error later.
        find_package(hipblaslt REQUIRED)
        target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
    endif()
endif()
226299if (BUILD_MPS)
227300 add_dependencies (bitsandbytes metallib )
228301 target_link_libraries (bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph" )
0 commit comments