Skip to content

Commit fd35abb

Browse files
committed
python3Packages.bitsandbytes: support rocm
1 parent 1b9868c commit fd35abb

2 files changed

Lines changed: 106 additions & 12 deletions

File tree

pkgs/development/python-modules/bitsandbytes/default.nix

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,43 @@
11
{
22
lib,
3-
torch,
3+
stdenv,
44
symlinkJoin,
55
buildPythonPackage,
66
fetchFromGitHub,
7+
78
cmake,
89

910
# build-system
1011
scikit-build-core,
1112
setuptools,
1213

1314
# dependencies
15+
torch,
1416
scipy,
1517
trove-classifiers,
18+
19+
cudaSupport ? torch.cudaSupport,
20+
cudaPackages ? torch.cudaPackages,
21+
rocmSupport ? torch.rocmSupport,
22+
rocmPackages ? torch.rocmPackages,
23+
24+
rocmGpuTargets ? rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets,
1625
}:
1726

1827
let
1928
pname = "bitsandbytes";
2029
version = "0.48.1";
2130

22-
inherit (torch) cudaPackages cudaSupport;
31+
brokenConditions = lib.attrsets.filterAttrs (_: cond: cond) {
32+
"CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
33+
"CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
34+
};
35+
2336
inherit (cudaPackages) cudaMajorMinorVersion;
37+
rocmMajorMinorVersion = lib.versions.majorMinor rocmPackages.rocm-core.version;
2438

2539
cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;
40+
rocmMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion;
2641

2742
# NOTE: torchvision doesn't use cudnn; torch does!
2843
# For this reason it is not included.
@@ -66,38 +81,84 @@ buildPythonPackage {
6681
hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ=";
6782
};
6883

84+
patches = [
85+
./find-rocm-deps-with-cmake.patch
86+
];
87+
6988
# By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
7089
# When `cudaSupport` or `rocmSupport` is enabled, bypass this check and load the GPU library unconditionally.
7190
# Indeed, in this case, only the GPU library (e.g. `libbitsandbytes_cuda124.so`) is built. `libbitsandbytes_cpu.so` is not.
7291
# Also, hardcode the path to the previously built library instead of relying on
7392
# `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
7493
#
75-
# WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
94+
# WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`
95+
# and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`.
7696
# This upstream convention could change at some point and thus break the following patch.
77-
postPatch = lib.optionalString cudaSupport ''
78-
substituteInPlace bitsandbytes/cextension.py \
79-
--replace-fail "if cuda_specs:" "if True:" \
80-
--replace-fail \
81-
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
82-
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
83-
'';
97+
postPatch = (
98+
let
99+
prefix = if cudaSupport then "cuda" else "rocm";
100+
majorMinorVersionString =
101+
if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString;
102+
in
103+
lib.optionalString (cudaSupport || rocmSupport) ''
104+
substituteInPlace bitsandbytes/cextension.py \
105+
--replace-fail "if cuda_specs:" "if True:" \
106+
--replace-fail \
107+
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
108+
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${prefix}${majorMinorVersionString}.so'"
109+
''
110+
);
84111

85112
nativeBuildInputs = [
86113
cmake
87114
]
88115
++ lib.optionals cudaSupport [
89116
cudaPackages.cuda_nvcc
117+
]
118+
++ lib.optionals rocmSupport [
119+
rocmPackages.clr
90120
];
91121

92122
build-system = [
93123
scikit-build-core
94124
setuptools
95125
];
96126

97-
buildInputs = lib.optionals cudaSupport [ cuda-redist ];
127+
buildInputs =
128+
lib.optional cudaSupport cuda-redist
129+
++ lib.optionals rocmSupport (
130+
with rocmPackages;
131+
[
132+
rocm-device-libs
133+
hipblas
134+
rocm-comgr
135+
rocm-runtime
136+
hiprand
137+
rocrand
138+
hipsparse
139+
hipblaslt
140+
rocblas
141+
hipcub
142+
rocprim
143+
]
144+
);
98145

99146
cmakeFlags = [
100-
(lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
147+
(lib.cmakeFeature "COMPUTE_BACKEND" (
148+
if cudaSupport then
149+
"cuda"
150+
else if rocmSupport then
151+
"hip"
152+
else
153+
"cpu"
154+
))
155+
]
156+
++ lib.optionals rocmSupport [
157+
# Force amdclang/amdclang++: otherwise the build ends up using g++ for some files it shouldn't.
158+
(lib.cmakeFeature "CMAKE_C_COMPILER" "amdclang")
159+
(lib.cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++")
160+
161+
(lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmGpuTargets))
101162
];
102163
CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";
103164
NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
@@ -120,6 +181,16 @@ buildPythonPackage {
120181

121182
pythonImportsCheck = [ "bitsandbytes" ];
122183

184+
passthru = {
185+
inherit
186+
cudaSupport
187+
cudaPackages
188+
rocmSupport
189+
rocmPackages
190+
brokenConditions # To help debug when a package is broken due to CUDA support
191+
;
192+
};
193+
123194
meta = {
124195
description = "8-bit CUDA functions for PyTorch";
125196
homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes";
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
diff --git a/CMakeLists.txt b/CMakeLists.txt
2+
index 9c133e0..1aa8a53 100644
3+
--- a/CMakeLists.txt
4+
+++ b/CMakeLists.txt
5+
@@ -285,6 +285,9 @@ if(BUILD_HIP)
6+
find_package_and_print_version(hipblas REQUIRED)
7+
find_package_and_print_version(hiprand REQUIRED)
8+
find_package_and_print_version(hipsparse REQUIRED)
9+
+ find_package_and_print_version(rocblas REQUIRED)
10+
+ find_package_and_print_version(hip REQUIRED)
11+
+ find_package_and_print_version(hipcub REQUIRED)
12+
13+
## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
14+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
15+
@@ -293,7 +296,7 @@ if(BUILD_HIP)
16+
17+
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
18+
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
19+
- target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
20+
+ target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse roc::rocblas hip::device hip::hipcub)
21+
22+
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
23+
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)

0 commit comments

Comments
 (0)