Skip to content

Commit c86eb5f

Browse files
python3Packages.bitsandbytes: 0.47.0 -> 0.48.1, support rocm (#443210)
2 parents 5c51fb8 + 480f7a5 commit c86eb5f

2 files changed

Lines changed: 114 additions & 15 deletions

File tree

pkgs/development/python-modules/bitsandbytes/default.nix

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,43 @@
11
{
22
lib,
3-
torch,
3+
stdenv,
44
symlinkJoin,
55
buildPythonPackage,
66
fetchFromGitHub,
7+
78
cmake,
89

910
# build-system
1011
scikit-build-core,
1112
setuptools,
1213

1314
# dependencies
15+
torch,
1416
scipy,
17+
trove-classifiers,
18+
19+
cudaSupport ? torch.cudaSupport,
20+
cudaPackages ? torch.cudaPackages,
21+
rocmSupport ? torch.rocmSupport,
22+
rocmPackages ? torch.rocmPackages,
23+
24+
rocmGpuTargets ? rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets,
1525
}:
1626

1727
let
1828
pname = "bitsandbytes";
19-
version = "0.47.0";
29+
version = "0.48.1";
30+
31+
brokenConditions = lib.attrsets.filterAttrs (_: cond: cond) {
32+
"CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
33+
"CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
34+
};
2035

21-
inherit (torch) cudaPackages cudaSupport;
2236
inherit (cudaPackages) cudaMajorMinorVersion;
37+
rocmMajorMinorVersion = lib.versions.majorMinor rocmPackages.rocm-core.version;
2338

2439
cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;
40+
rocmMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion;
2541

2642
# NOTE: torchvision doesn't use cudnn; torch does!
2743
# For this reason it is not included.
@@ -62,41 +78,87 @@ buildPythonPackage {
6278
owner = "bitsandbytes-foundation";
6379
repo = "bitsandbytes";
6480
tag = version;
65-
hash = "sha256-iUAeiNbPa3Q5jJ4lK2G0WvTKuipb0zO1mNe+wcRdnqs=";
81+
hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ=";
6682
};
6783

84+
patches = [
85+
./find-rocm-deps-with-cmake.patch
86+
];
87+
6888
# By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
6989
# When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally.
7090
# Indeed, in this case, only `libbitsandbytes_cuda124.so` is built. `libbitsandbytes_cpu.so` is not.
7191
# Also, hardcode the path to the previously built library instead of relying on
7292
# `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
7393
#
74-
# WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
94+
# WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`
95+
# and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`
7596
# This upstream convention could change at some point and thus break the following patch.
76-
postPatch = lib.optionalString cudaSupport ''
77-
substituteInPlace bitsandbytes/cextension.py \
78-
--replace-fail "if cuda_specs:" "if True:" \
79-
--replace-fail \
80-
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
81-
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
82-
'';
97+
postPatch = (
98+
let
99+
prefix = if cudaSupport then "cuda" else "rocm";
100+
majorMinorVersionString =
101+
if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString;
102+
in
103+
lib.optionalString (cudaSupport || rocmSupport) ''
104+
substituteInPlace bitsandbytes/cextension.py \
105+
--replace-fail "if cuda_specs:" "if True:" \
106+
--replace-fail \
107+
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
108+
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${prefix}${majorMinorVersionString}.so'"
109+
''
110+
);
83111

84112
nativeBuildInputs = [
85113
cmake
86114
]
87115
++ lib.optionals cudaSupport [
88116
cudaPackages.cuda_nvcc
117+
]
118+
++ lib.optionals rocmSupport [
119+
rocmPackages.clr
89120
];
90121

91122
build-system = [
92123
scikit-build-core
93124
setuptools
94125
];
95126

96-
buildInputs = lib.optionals cudaSupport [ cuda-redist ];
127+
buildInputs =
128+
lib.optional cudaSupport cuda-redist
129+
++ lib.optionals rocmSupport (
130+
with rocmPackages;
131+
[
132+
rocm-device-libs
133+
hipblas
134+
rocm-comgr
135+
rocm-runtime
136+
hiprand
137+
rocrand
138+
hipsparse
139+
hipblaslt
140+
rocblas
141+
hipcub
142+
rocprim
143+
]
144+
);
97145

98146
cmakeFlags = [
99-
(lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
147+
(lib.cmakeFeature "COMPUTE_BACKEND" (
148+
if cudaSupport then
149+
"cuda"
150+
else if rocmSupport then
151+
"hip"
152+
else
153+
"cpu"
154+
))
155+
]
156+
++ lib.optionals rocmSupport [
157+
# ends up using g++ to build some files it shouldn't
158+
(lib.cmakeFeature "CMAKE_C_COMPILER" "amdclang")
159+
(lib.cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++")
160+
161+
(lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmGpuTargets))
100162
];
101163
CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";
102164
NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
@@ -112,17 +174,31 @@ buildPythonPackage {
112174
dependencies = [
113175
scipy
114176
torch
177+
trove-classifiers
115178
];
116179

117180
doCheck = false; # tests require CUDA and also GPU access
118181

119182
pythonImportsCheck = [ "bitsandbytes" ];
120183

184+
passthru = {
185+
inherit
186+
cudaSupport
187+
cudaPackages
188+
rocmSupport
189+
rocmPackages
190+
brokenConditions # To help debug when a package is broken due to CUDA support
191+
;
192+
};
193+
121194
meta = {
122195
description = "8-bit CUDA functions for PyTorch";
123196
homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes";
124197
changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/${version}";
125198
license = lib.licenses.mit;
126-
maintainers = with lib.maintainers; [ bcdarwin ];
199+
maintainers = with lib.maintainers; [
200+
bcdarwin
201+
jk
202+
];
127203
};
128204
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
diff --git a/CMakeLists.txt b/CMakeLists.txt
2+
index 9c133e0..1aa8a53 100644
3+
--- a/CMakeLists.txt
4+
+++ b/CMakeLists.txt
5+
@@ -285,6 +285,9 @@ if(BUILD_HIP)
6+
find_package_and_print_version(hipblas REQUIRED)
7+
find_package_and_print_version(hiprand REQUIRED)
8+
find_package_and_print_version(hipsparse REQUIRED)
9+
+ find_package_and_print_version(rocblas REQUIRED)
10+
+ find_package_and_print_version(hip REQUIRED)
11+
+ find_package_and_print_version(hipcub REQUIRED)
12+
13+
## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
14+
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
15+
@@ -293,7 +296,7 @@ if(BUILD_HIP)
16+
17+
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
18+
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
19+
- target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
20+
+ target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse roc::rocblas hip::device hip::hipcub)
21+
22+
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
23+
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)

0 commit comments

Comments
 (0)