11{
22 lib ,
3- torch ,
3+ stdenv ,
44 symlinkJoin ,
55 buildPythonPackage ,
66 fetchFromGitHub ,
7+
78 cmake ,
89
910 # build-system
1011 scikit-build-core ,
1112 setuptools ,
1213
1314 # dependencies
15+ torch ,
1416 scipy ,
17+ trove-classifiers ,
18+
19+ cudaSupport ? torch . cudaSupport ,
20+ cudaPackages ? torch . cudaPackages ,
21+ rocmSupport ? torch . rocmSupport ,
22+ rocmPackages ? torch . rocmPackages ,
23+
24+ rocmGpuTargets ? rocmPackages . clr . localGpuTargets or rocmPackages . clr . gpuTargets ,
1525} :
1626
1727let
1828 pname = "bitsandbytes" ;
19- version = "0.47.0" ;
29+ version = "0.48.1" ;
30+
31+ brokenConditions = lib . attrsets . filterAttrs ( _ : cond : cond ) {
32+ "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport ;
33+ "CUDA is not targeting Linux" = cudaSupport && ! stdenv . hostPlatform . isLinux ;
34+ } ;
2035
21- inherit ( torch ) cudaPackages cudaSupport ;
2236 inherit ( cudaPackages ) cudaMajorMinorVersion ;
37+ rocmMajorMinorVersion = lib . versions . majorMinor rocmPackages . rocm-core . version ;
2338
2439 cudaMajorMinorVersionString = lib . replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion ;
40+ rocmMajorMinorVersionString = lib . replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion ;
2541
2642 # NOTE: torchvision doesn't use cudnn; torch does!
2743 # For this reason it is not included.
@@ -62,41 +78,87 @@ buildPythonPackage {
6278 owner = "bitsandbytes-foundation" ;
6379 repo = "bitsandbytes" ;
6480 tag = version ;
65- hash = "sha256-iUAeiNbPa3Q5jJ4lK2G0WvTKuipb0zO1mNe+wcRdnqs =" ;
81+ hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ =" ;
6682 } ;
6783
84+ patches = [
85+ ./find-rocm-deps-with-cmake.patch
86+ ] ;
87+
6888 # By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
6989 # When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally.
7090 # Indeed, in this case, only `libbitsandbytes_cuda124.so` is built. `libbitsandbytes_cpu.so` is not.
7191 # Also, hardcode the path to the previously built library instead of relying on
7292 # `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
7393 #
74- # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
94+ # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`
95+ # and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`
7596 # This upstream convention could change at some point and thus break the following patch.
76- postPatch = lib . optionalString cudaSupport ''
77- substituteInPlace bitsandbytes/cextension.py \
78- --replace-fail "if cuda_specs:" "if True:" \
79- --replace-fail \
80- "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
81- "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${ cudaMajorMinorVersionString } .so'"
82- '' ;
97+ postPatch = (
98+ let
99+ prefix = if cudaSupport then "cuda" else "rocm" ;
100+ majorMinorVersionString =
101+ if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString ;
102+ in
103+ lib . optionalString ( cudaSupport || rocmSupport ) ''
104+ substituteInPlace bitsandbytes/cextension.py \
105+ --replace-fail "if cuda_specs:" "if True:" \
106+ --replace-fail \
107+ "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
108+ "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${ prefix } ${ majorMinorVersionString } .so'"
109+ ''
110+ ) ;
83111
84112 nativeBuildInputs = [
85113 cmake
86114 ]
87115 ++ lib . optionals cudaSupport [
88116 cudaPackages . cuda_nvcc
117+ ]
118+ ++ lib . optionals rocmSupport [
119+ rocmPackages . clr
89120 ] ;
90121
91122 build-system = [
92123 scikit-build-core
93124 setuptools
94125 ] ;
95126
96- buildInputs = lib . optionals cudaSupport [ cuda-redist ] ;
127+ buildInputs =
128+ lib . optional cudaSupport cuda-redist
129+ ++ lib . optionals rocmSupport (
130+ with rocmPackages ;
131+ [
132+ rocm-device-libs
133+ hipblas
134+ rocm-comgr
135+ rocm-runtime
136+ hiprand
137+ rocrand
138+ hipsparse
139+ hipblaslt
140+ rocblas
141+ hipcub
142+ rocprim
143+ ]
144+ ) ;
97145
98146 cmakeFlags = [
99- ( lib . cmakeFeature "COMPUTE_BACKEND" ( if cudaSupport then "cuda" else "cpu" ) )
147+ ( lib . cmakeFeature "COMPUTE_BACKEND" (
148+ if cudaSupport then
149+ "cuda"
150+ else if rocmSupport then
151+ "hip"
152+ else
153+ "cpu"
154+ ) )
155+ ]
156+ ++ lib . optionals rocmSupport [
157+ # ends up using g++ to build some files it shouldn't
158+ ( lib . cmakeFeature "CMAKE_C_COMPILER" "amdclang" )
159+ ( lib . cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++" )
160+
161+ ( lib . cmakeFeature "CMAKE_HIP_ARCHITECTURES" ( builtins . concatStringsSep ";" rocmGpuTargets ) )
100162 ] ;
101163 CUDA_HOME = lib . optionalString cudaSupport "${ cuda-native-redist } " ;
102164 NVCC_PREPEND_FLAGS = lib . optionals cudaSupport [
@@ -112,17 +174,31 @@ buildPythonPackage {
112174 dependencies = [
113175 scipy
114176 torch
177+ trove-classifiers
115178 ] ;
116179
117180 doCheck = false ; # tests require CUDA and also GPU access
118181
119182 pythonImportsCheck = [ "bitsandbytes" ] ;
120183
184+ passthru = {
185+ inherit
186+ cudaSupport
187+ cudaPackages
188+ rocmSupport
189+ rocmPackages
190+ brokenConditions # To help debug when a package is broken due to CUDA support
191+ ;
192+ } ;
193+
121194 meta = {
122195 description = "8-bit CUDA functions for PyTorch" ;
123196 homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes" ;
124197 changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/${ version } " ;
125198 license = lib . licenses . mit ;
126- maintainers = with lib . maintainers ; [ bcdarwin ] ;
199+ maintainers = with lib . maintainers ; [
200+ bcdarwin
201+ jk
202+ ] ;
127203 } ;
128204}
0 commit comments