diff --git a/pkgs/development/python-modules/bitsandbytes/default.nix b/pkgs/development/python-modules/bitsandbytes/default.nix index d1767fa005beb..125386dec9b65 100644 --- a/pkgs/development/python-modules/bitsandbytes/default.nix +++ b/pkgs/development/python-modules/bitsandbytes/default.nix @@ -1,9 +1,10 @@ { lib, - torch, + stdenv, symlinkJoin, buildPythonPackage, fetchFromGitHub, + cmake, # build-system @@ -11,17 +12,32 @@ setuptools, # dependencies + torch, scipy, + trove-classifiers, + + cudaSupport ? torch.cudaSupport, + cudaPackages ? torch.cudaPackages, + rocmSupport ? torch.rocmSupport, + rocmPackages ? torch.rocmPackages, + + rocmGpuTargets ? rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets, }: let pname = "bitsandbytes"; - version = "0.47.0"; + version = "0.48.1"; + + brokenConditions = lib.attrsets.filterAttrs (_: cond: cond) { + "CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport; + "CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux; + }; - inherit (torch) cudaPackages cudaSupport; inherit (cudaPackages) cudaMajorMinorVersion; + rocmMajorMinorVersion = lib.versions.majorMinor rocmPackages.rocm-core.version; cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion; + rocmMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion; # NOTE: torchvision doesn't use cudnn; torch does! # For this reason it is not included. @@ -62,30 +78,45 @@ buildPythonPackage { owner = "bitsandbytes-foundation"; repo = "bitsandbytes"; tag = version; - hash = "sha256-iUAeiNbPa3Q5jJ4lK2G0WvTKuipb0zO1mNe+wcRdnqs="; + hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ="; }; + patches = [ + ./find-rocm-deps-with-cmake.patch + ]; + # By default, which library is loaded depends on the result of `torch.cuda.is_available()`. # When `cudaSupport` is enabled, bypass this check and load the cuda library unconditionally. 
# Indeed, in this case, only `libbitsandbytes_cuda124.so` is built. `libbitsandbytes_cpu.so` is not. # Also, hardcode the path to the previously built library instead of relying on # `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too. # - # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`. + # WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y` + # and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`. # This upstream convention could change at some point and thus break the following patch. - postPatch = lib.optionalString cudaSupport '' - substituteInPlace bitsandbytes/cextension.py \ - --replace-fail "if cuda_specs:" "if True:" \ - --replace-fail \ - "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \ - "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'" - ''; + postPatch = ( + let + prefix = if cudaSupport then "cuda" else "rocm"; + majorMinorVersionString = + if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString; + in + lib.optionalString (cudaSupport || rocmSupport) '' + substituteInPlace bitsandbytes/cextension.py \ + --replace-fail "if cuda_specs:" "if True:" \ + --replace-fail \ + "cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \ + "cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${prefix}${majorMinorVersionString}.so'" + '' + ); nativeBuildInputs = [ cmake ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc + ] + ++ lib.optionals rocmSupport [ + rocmPackages.clr ]; build-system = [ @@ -93,10 +124,41 @@ buildPythonPackage { setuptools ]; - buildInputs = lib.optionals cudaSupport [ cuda-redist ]; + buildInputs = + lib.optional cudaSupport cuda-redist + ++ lib.optionals rocmSupport ( + with rocmPackages; + [ + rocm-device-libs + hipblas + rocm-comgr + rocm-runtime + hiprand + rocrand + hipsparse + hipblaslt + rocblas + hipcub + rocprim + ] + ); cmakeFlags = [ - 
(lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu")) + (lib.cmakeFeature "COMPUTE_BACKEND" ( + if cudaSupport then + "cuda" + else if rocmSupport then + "hip" + else + "cpu" + )) + ] + ++ lib.optionals rocmSupport [ + # ends up using g++ to build some files it shouldn't + (lib.cmakeFeature "CMAKE_C_COMPILER" "amdclang") + (lib.cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++") + + (lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmGpuTargets)) ]; CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}"; NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [ @@ -112,17 +174,31 @@ buildPythonPackage { dependencies = [ scipy torch + trove-classifiers ]; doCheck = false; # tests require CUDA and also GPU access pythonImportsCheck = [ "bitsandbytes" ]; + passthru = { + inherit + cudaSupport + cudaPackages + rocmSupport + rocmPackages + brokenConditions # To help debug when a package is broken due to CUDA support + ; + }; + meta = { description = "8-bit CUDA functions for PyTorch"; homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes"; changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/${version}"; license = lib.licenses.mit; - maintainers = with lib.maintainers; [ bcdarwin ]; + maintainers = with lib.maintainers; [ + bcdarwin + jk + ]; }; } diff --git a/pkgs/development/python-modules/bitsandbytes/find-rocm-deps-with-cmake.patch b/pkgs/development/python-modules/bitsandbytes/find-rocm-deps-with-cmake.patch new file mode 100644 index 0000000000000..2650cdb8cabf9 --- /dev/null +++ b/pkgs/development/python-modules/bitsandbytes/find-rocm-deps-with-cmake.patch @@ -0,0 +1,23 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 9c133e0..1aa8a53 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -285,6 +285,9 @@ if(BUILD_HIP) + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse 
REQUIRED) ++ find_package_and_print_version(rocblas REQUIRED) ++ find_package_and_print_version(hip REQUIRED) ++ find_package_and_print_version(hipcub REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") +@@ -293,7 +296,7 @@ if(BUILD_HIP) + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) +- target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) ++ target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse roc::rocblas hip::device hip::hipcub) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)