Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 91 additions & 15 deletions pkgs/development/python-modules/bitsandbytes/default.nix
Original file line number Diff line number Diff line change
@@ -1,27 +1,43 @@
{
lib,
torch,
stdenv,
symlinkJoin,
buildPythonPackage,
fetchFromGitHub,

cmake,

# build-system
scikit-build-core,
setuptools,

# dependencies
torch,
scipy,
trove-classifiers,

cudaSupport ? torch.cudaSupport,
cudaPackages ? torch.cudaPackages,
rocmSupport ? torch.rocmSupport,
rocmPackages ? torch.rocmPackages,

rocmGpuTargets ? rocmPackages.clr.localGpuTargets or rocmPackages.clr.gpuTargets,
}:

let
pname = "bitsandbytes";
version = "0.47.0";
version = "0.48.1";

brokenConditions = lib.attrsets.filterAttrs (_: cond: cond) {
"CUDA and ROCm are mutually exclusive" = cudaSupport && rocmSupport;
"CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"CUDA is not targeting Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;
"CUDA is only available for Linux" = cudaSupport && !stdenv.hostPlatform.isLinux;

};

inherit (torch) cudaPackages cudaSupport;
inherit (cudaPackages) cudaMajorMinorVersion;
rocmMajorMinorVersion = lib.versions.majorMinor rocmPackages.rocm-core.version;

cudaMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] cudaMajorMinorVersion;
rocmMajorMinorVersionString = lib.replaceStrings [ "." ] [ "" ] rocmMajorMinorVersion;

# NOTE: torchvision doesn't use cudnn; torch does!
# For this reason it is not included.
Expand Down Expand Up @@ -62,41 +78,87 @@ buildPythonPackage {
owner = "bitsandbytes-foundation";
repo = "bitsandbytes";
tag = version;
hash = "sha256-iUAeiNbPa3Q5jJ4lK2G0WvTKuipb0zO1mNe+wcRdnqs=";
hash = "sha256-OkhWv5Mb/cnWJteCXvDEkWQvK+QK26YQex39yWIezrQ=";
};

patches = [
./find-rocm-deps-with-cmake.patch
];

# By default, which library is loaded depends on the result of `torch.cuda.is_available()`.
# When `cudaSupport` or `rocmSupport` is enabled, bypass this check and load the GPU library unconditionally.
# Indeed, in this case, only the GPU library (e.g. `libbitsandbytes_cuda124.so`) is built. `libbitsandbytes_cpu.so` is not.
# Also, hardcode the path to the previously built library instead of relying on
# `get_cuda_bnb_library_path(cuda_specs)` which relies on `torch.cuda` too.
#
# WARNING: The cuda library is currently named `libbitsandbytes_cudaxxy` for cuda version `xx.y`.
# WARNING: The GPU library is currently named `libbitsandbytes_cudaxxy` for CUDA version `xx.y`
# and `libbitsandbytes_rocmxxy` for ROCm version `xx.y`.
# This upstream convention could change at some point and thus break the following patch.
postPatch = lib.optionalString cudaSupport ''
substituteInPlace bitsandbytes/cextension.py \
--replace-fail "if cuda_specs:" "if True:" \
--replace-fail \
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_cuda${cudaMajorMinorVersionString}.so'"
'';
postPatch = (
let
prefix = if cudaSupport then "cuda" else "rocm";
majorMinorVersionString =
if cudaSupport then cudaMajorMinorVersionString else rocmMajorMinorVersionString;
in
lib.optionalString (cudaSupport || rocmSupport) ''
substituteInPlace bitsandbytes/cextension.py \
--replace-fail "if cuda_specs:" "if True:" \
--replace-fail \
"cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)" \
"cuda_binary_path = PACKAGE_DIR / 'libbitsandbytes_${prefix}${majorMinorVersionString}.so'"
''
);

nativeBuildInputs = [
cmake
]
++ lib.optionals cudaSupport [
cudaPackages.cuda_nvcc
]
++ lib.optionals rocmSupport [
rocmPackages.clr
];

build-system = [
scikit-build-core
setuptools
];

buildInputs = lib.optionals cudaSupport [ cuda-redist ];
buildInputs =
lib.optional cudaSupport cuda-redist
++ lib.optionals rocmSupport (
with rocmPackages;
[
rocm-device-libs
hipblas
rocm-comgr
rocm-runtime
hiprand
rocrand
hipsparse
hipblaslt
rocblas
hipcub
rocprim
]
);

cmakeFlags = [
(lib.cmakeFeature "COMPUTE_BACKEND" (if cudaSupport then "cuda" else "cpu"))
(lib.cmakeFeature "COMPUTE_BACKEND" (
if cudaSupport then
"cuda"
else if rocmSupport then
"hip"
else
"cpu"
))
]
++ lib.optionals rocmSupport [
# Force the ROCm clang toolchain; otherwise the build ends up using g++ for some files it shouldn't.
(lib.cmakeFeature "CMAKE_C_COMPILER" "amdclang")
(lib.cmakeFeature "CMAKE_CXX_COMPILER" "amdclang++")

(lib.cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmGpuTargets))
];
CUDA_HOME = lib.optionalString cudaSupport "${cuda-native-redist}";
NVCC_PREPEND_FLAGS = lib.optionals cudaSupport [
Expand All @@ -112,17 +174,31 @@ buildPythonPackage {
dependencies = [
scipy
torch
trove-classifiers
];

doCheck = false; # tests require CUDA and also GPU access

pythonImportsCheck = [ "bitsandbytes" ];

passthru = {
inherit
cudaSupport
cudaPackages
rocmSupport
rocmPackages
brokenConditions # To help debug when a package is broken due to CUDA support
;
};

meta = {
description = "8-bit CUDA functions for PyTorch";
homepage = "https://github.com/bitsandbytes-foundation/bitsandbytes";
changelog = "https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/${version}";
license = lib.licenses.mit;
maintainers = with lib.maintainers; [ bcdarwin ];
maintainers = with lib.maintainers; [
bcdarwin
jk
];
};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9c133e0..1aa8a53 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -285,6 +285,9 @@ if(BUILD_HIP)
find_package_and_print_version(hipblas REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(hipsparse REQUIRED)
+ find_package_and_print_version(rocblas REQUIRED)
+ find_package_and_print_version(hip REQUIRED)
+ find_package_and_print_version(hipcub REQUIRED)

## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
@@ -293,7 +296,7 @@ if(BUILD_HIP)

target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
- target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
+ target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse roc::rocblas hip::device hip::hipcub)

target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
Loading