Skip to content

Commit d5a0a90

Browse files
committed
Adding support for building for AMD on Windows
1 parent fe784a7 commit d5a0a90

File tree

4 files changed

+56
-8
lines changed

4 files changed

+56
-8
lines changed

CMakeLists.txt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ endif()
263263
if(WIN32)
264264
# Export all symbols
265265
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
266+
# Prevent Windows SDK min/max macros from conflicting with std::min/std::max
267+
add_compile_definitions(NOMINMAX)
266268
endif()
267269

268270
if(MSVC)
@@ -330,14 +332,22 @@ if(BUILD_HIP)
330332
find_package_and_print_version(hipsparse REQUIRED)
331333

332334
## hacky way of excluding hip::amdhip64 (when it is linked, many tests unexpectedly fail, e.g. adam8bit, because of inaccuracies)
333-
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
334-
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
335-
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
335+
## On Windows, amdhip64 must be linked explicitly, so the exclusion below is skipped there
336+
if(NOT WIN32)
337+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
338+
set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
339+
set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
340+
endif()
336341

337342
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
338343
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
339344
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
340345

346+
# On Windows, link the HIP runtime and rocblas directly
347+
if(WIN32)
348+
target_link_libraries(bitsandbytes PUBLIC amdhip64 rocblas)
349+
endif()
350+
341351
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
342352
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
343353
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

bitsandbytes/cuda_specs.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
from functools import lru_cache
33
import logging
4+
import os
45
import re
56
import subprocess
67
from typing import Optional
@@ -83,10 +84,21 @@ def get_rocm_gpu_arch() -> str:
8384
logger = logging.getLogger(__name__)
8485
try:
8586
if torch.version.hip:
86-
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
87-
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
87+
# On Windows (os.name == "nt"), use hipinfo.exe; on all other platforms, use rocminfo
88+
if os.name == "nt":
89+
cmd = ["hipinfo.exe"]
90+
arch_pattern = r"gcnArchName:\s+(gfx[a-zA-Z\d]+)"
91+
else:
92+
cmd = ["rocminfo"]
93+
arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"
94+
95+
result = subprocess.run(cmd, capture_output=True, text=True)
96+
match = re.search(arch_pattern, result.stdout)
8897
if match:
89-
return "gfx" + match.group(1)
98+
if os.name == "nt":
99+
return match.group(1)
100+
else:
101+
return "gfx" + match.group(1)
90102
else:
91103
return "unknown"
92104
else:
@@ -107,8 +119,17 @@ def get_rocm_warpsize() -> int:
107119
logger = logging.getLogger(__name__)
108120
try:
109121
if torch.version.hip:
110-
result = subprocess.run(["rocminfo"], capture_output=True, text=True)
111-
match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
122+
# On Windows (os.name == "nt"), use hipinfo.exe; on all other platforms, use rocminfo
123+
if os.name == "nt":
124+
cmd = ["hipinfo.exe"]
125+
# hipinfo.exe output format: "warpSize: 32" or "warpSize: 64"
126+
warp_pattern = r"warpSize:\s+(\d+)"
127+
else:
128+
cmd = ["rocminfo"]
129+
warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)"
130+
131+
result = subprocess.run(cmd, capture_output=True, text=True)
132+
match = re.search(warp_pattern, result.stdout)
112133
if match:
113134
return int(match.group(1))
114135
else:

csrc/ops.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
#include <cstdint>
1111
#include <iostream>
1212
#include <stdio.h>
13+
#ifdef _WIN32
14+
#include <io.h>
15+
#include <process.h>
16+
#include <windows.h>
17+
#else
18+
#include <unistd.h>
19+
#endif
1320

1421
#include <common.h>
1522
#include <cublasLt.h>

csrc/ops_hip.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@
1111
#include <cstdint>
1212
#include <iostream>
1313
#include <stdio.h>
14+
15+
#ifdef _WIN32
16+
#ifndef NOMINMAX
17+
#define NOMINMAX
18+
#endif
19+
#include <io.h>
20+
#include <process.h>
21+
#include <windows.h>
22+
#else
1423
#include <unistd.h>
24+
#endif
1525

1626
#include <common.h>
1727
#include <functional>

0 commit comments

Comments
 (0)