11{
22 lib ,
3- torch ,
3+ stdenv ,
44 symlinkJoin ,
55 buildPythonPackage ,
66 fetchFromGitHub ,
7+
78 cmake ,
89
910 # build-system
1011 scikit-build-core ,
1112 setuptools ,
1213
1314 # dependencies
15+ torch ,
1416 scipy ,
1517 trove-classifiers ,
18+
19+ cudaSupport ? torch . cudaSupport ,
20+ cudaPackages ? torch . cudaPackages ,
21+ rocmSupport ? torch . rocmSupport ,
22+ rocmPackages ? torch . rocmPackages ,
23+
24+ rocmGpuTargets ? rocmPackages . clr . localGpuTargets or rocmPackages . clr . gpuTargets ,
1625} :
1726
1827let
1928 pname = "bitsandbytes" ;
2029 version = "0.48.1" ;
2130
22- inherit ( torch ) cudaPackages cudaSupport ;
31+ brokenConditions = lib . attrsets . filterAttrs ( _ : cond : cond ) {
32+ "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport ;
33+ "CUDA is not targeting Linux" = cudaSupport && ! stdenv . hostPlatform . isLinux ;
34+ } ;
35+
2336 inherit ( cudaPackages ) cudaMajorMinorVersion ;
37+ rocmMajorMinorVersion = lib . versions . majorMinor rocmPackages . rocm-core . version ;
2438
2539 cudaMajorMinorVersionString = lib . replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion ;
40+ rocmMajorMinorVersionString = lib . replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion ;
2641
2742 # NOTE: torchvision doesn't use cudnn; torch does!
2843 # For this reason it is not included.
@@ -66,38 +81,84 @@ buildPythonPackage {
6681 hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ=" ;
6782 } ;
6883
84+ patches = [
85+ ./find-rocm-deps-with-cmake.patch
86+ ] ;
87+
6988 # By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
7089 # When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally.
7190 # Indeed, in this case, only `libbitsandbytes_cuda124.so` is built. `libbitsandbytes_cpu.so` is not.
7291 # Also, hardcode the path to the previously built library instead of relying on
7392 # `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
7493 #
75- # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
94+ # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`
95+ # and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`
7696 # This upstream convention could change at some point and thus break the following patch.
77- postPatch = lib . optionalString cudaSupport ''
78- substituteInPlace bitsandbytes/cextension.py \
79- --replace-fail "if cuda_specs:" "if True:" \
80- --replace-fail \
81- "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
82- "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${ cudaMajorMinorVersionString } .so'"
83- '' ;
97+ postPatch = (
98+ let
99+ prefix = if cudaSupport then "cuda" else "rocm" ;
100+ majorMinorVersionString =
101+ if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString ;
102+ in
103+ lib . optionalString ( cudaSupport || rocmSupport ) ''
104+ substituteInPlace bitsandbytes/cextension.py \
105+ --replace-fail "if cuda_specs:" "if True:" \
106+ --replace-fail \
107+ "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
108+ "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${ prefix } ${ majorMinorVersionString } .so'"
109+ ''
110+ ) ;
84111
85112 nativeBuildInputs = [
86113 cmake
87114 ]
88115 ++ lib . optionals cudaSupport [
89116 cudaPackages . cuda_nvcc
117+ ]
118+ ++ lib . optionals rocmSupport [
119+ rocmPackages . clr
90120 ] ;
91121
92122 build-system = [
93123 scikit-build-core
94124 setuptools
95125 ] ;
96126
97- buildInputs = lib . optionals cudaSupport [ cuda-redist ] ;
127+ buildInputs =
128+ lib . optional cudaSupport cuda-redist
129+ ++ lib . optionals rocmSupport (
130+ with rocmPackages ;
131+ [
132+ rocm-device-libs
133+ hipblas
134+ rocm-comgr
135+ rocm-runtime
136+ hiprand
137+ rocrand
138+ hipsparse
139+ hipblaslt
140+ rocblas
141+ hipcub
142+ rocprim
143+ ]
144+ ) ;
98145
99146 cmakeFlags = [
100- ( lib . cmakeFeature "COMPUTE_BACKEND" ( if cudaSupport then "cuda" else "cpu" ) )
147+ ( lib . cmakeFeature "COMPUTE_BACKEND" (
148+ if cudaSupport then
149+ "cuda"
150+ else if rocmSupport then
151+ "hip"
152+ else
153+ "cpu"
154+ ) )
155+ ]
156+ ++ lib . optionals rocmSupport [
157+ # ends up using g++ to build some files it shouldn't
158+ ( lib . cmakeFeature "CMAKE_C_COMPILER" "amdclang" )
159+ ( lib . cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++" )
160+
161+ ( lib . cmakeFeature "CMAKE_HIP_ARCHITECTURES" ( builtins . concatStringsSep ";" rocmGpuTargets ) )
101162 ] ;
102163 CUDA_HOME = lib . optionalString cudaSupport "${ cuda-native-redist } " ;
103164 NVCC_PREPEND_FLAGS = lib . optionals cudaSupport [
@@ -120,6 +181,16 @@ buildPythonPackage {
120181
121182 pythonImportsCheck = [ "bitsandbytes" ] ;
122183
184+ passthru = {
185+ inherit
186+ cudaSupport
187+ cudaPackages
188+ rocmSupport
189+ rocmPackages
190+ brokenConditions # To help debug when a package is broken due to CUDA support
191+ ;
192+ } ;
193+
123194 meta = {
124195 description = "8-bit CUDA functions for PyTorch" ;
125196 homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes" ;
0 commit comments