Skip to content

Commit 56c48bc

Browse files
authored
Merge branch 'main' into main
2 parents b88236a + fd2949a commit 56c48bc

24 files changed

Lines changed: 4792 additions & 106 deletions

.github/scripts/build-rocm.sh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#!/bin/bash
# Build the bitsandbytes ROCm (HIP) shared library inside a ROCm development
# container and copy the resulting libraries to output/<build_os>/<build_arch>/.
#
# Required environment variables (with `set -u` below, referencing an unset one
# aborts the script):
#   build_arch    architecture used for the Docker platform (linux/<build_arch>)
#                 and for the output directory name
#   build_os      runner OS label; only labels starting with "ubuntu" are supported
#   rocm_version  version part of the rocm/dev-ubuntu-22.04:<version>-complete image tag
declare build_arch
declare build_os
declare rocm_version

set -xeuo pipefail

# GPU architectures handed to CMake through -DBNB_ROCM_ARCH.
bnb_rocm_arch="gfx90a;gfx942;gfx1100"

if [ "${build_os:0:6}" == ubuntu ]; then
    image=rocm/dev-ubuntu-22.04:${rocm_version}-complete
    echo "Using image $image"
    # The checkout is bind-mounted at /src, so the build products land in the
    # working tree on the host and can be collected after the container exits.
    docker run --rm --platform "linux/$build_arch" -i \
        -w /src -v "$PWD:/src" "$image" sh -c \
        "apt-get update \
        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
        && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \
        && cmake --build ."
else
    # Without this branch the build would be skipped silently and the copy step
    # below could pick up stale libraries already present in bitsandbytes/.
    echo "Unsupported build_os '${build_os}': the ROCm build only supports ubuntu runners" >&2
    exit 1
fi

output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
# nullglob drops the patterns that match nothing (e.g. *.dylib on Linux) instead
# of passing them to cp literally; the subshell keeps the option change local.
(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}")

.github/workflows/python-package.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,55 @@ jobs:
102102
path: output/*
103103
retention-days: 7
104104

105+
build-shared-libs-rocm:
  # Builds the ROCm shared library once for every ROCm version in the matrix
  # and uploads each result as its own artifact.
  runs-on: ${{ matrix.os }}
  strategy:
    matrix:
      os:
        - ubuntu-22.04
      arch:
        - x86_64
      rocm_version:
        - "6.1.2"
        - "6.2.4"
        - "6.3.2"
  steps:
    - uses: actions/checkout@v4
    - name: Set up Docker multiarch
      uses: docker/setup-qemu-action@v3
    # Delete preinstalled tooling from the runner before the build starts.
    - name: Clean up disk space
      run: |
        sudo rm -rf \
          /usr/share/dotnet \
          /opt/ghc \
          "/usr/local/share/boost" \
          "$AGENT_TOOLSDIRECTORY" \
          /opt/hostedtoolcache \
          /opt/google/chrome \
          /opt/microsoft/msedge \
          /opt/microsoft/powershell \
          /opt/pipx \
          /usr/lib/mono \
          /usr/local/julia* \
          /usr/local/lib/android \
          /usr/local/lib/node_modules \
          /usr/local/share/chromium \
          /usr/local/share/powershell \
          /usr/share/swift
    # The build script reads its configuration from these environment variables.
    - name: Build C++
      env:
        rocm_version: ${{ matrix.rocm_version }}
        build_arch: ${{ matrix.arch }}
        build_os: ${{ matrix.os }}
      run: bash .github/scripts/build-rocm.sh
    - name: Upload build artifact
      uses: actions/upload-artifact@v4
      with:
        # The artifact name includes the ROCm version so matrix entries do not collide.
        name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }}
        path: output/*
        retention-days: 7
148+
105149
build-wheels:
106150
needs:
107151
- build-shared-libs
108152
- build-shared-libs-cuda
153+
- build-shared-libs-rocm
109154
strategy:
110155
matrix:
111156
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest]
@@ -173,6 +218,7 @@ jobs:
173218
merge-multiple: true
174219

175220
- name: Inspect tmp directory after downloading artifacts
221+
176222
run: |
177223
ls -alFR tmp/
178224
WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l)
@@ -210,6 +256,7 @@ jobs:
210256
- uses: actions/checkout@v4
211257
with:
212258
path: repo
259+
213260
- name: Delete old pre-release (if exists)
214261
run: |
215262
cd repo && gh release delete continuous-release_main --cleanup-tag -y

CMakeLists.txt

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ endif()
2525
# Define included source files
2626
set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
2727
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
28+
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
2829
set(MPS_FILES csrc/mps_ops.mm)
2930
set(METAL_FILES csrc/mps_kernels.metal)
3031
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
3132
# C++ sources are always included
3233
list(APPEND SRC_FILES ${CPP_FILES})
3334

34-
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps, xpu)")
35-
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps xpu)
35+
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
36+
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
3637
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
3738

3839
if(APPLE)
@@ -48,12 +49,21 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
4849
message(FATAL_ERROR "CUDA is not supported on macOS" )
4950
endif()
5051
set(BUILD_CUDA ON)
52+
set(BUILD_HIP OFF)
53+
set(BUILD_MPS OFF)
54+
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
55+
if(APPLE)
56+
message(FATAL_ERROR "HIP is not supported on macOS" )
57+
endif()
58+
set(BUILD_CUDA OFF)
59+
set(BUILD_HIP ON)
5160
set(BUILD_MPS OFF)
5261
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
5362
if(NOT APPLE)
5463
message(FATAL_ERROR "MPS is only supported on macOS" )
5564
endif()
5665
set(BUILD_CUDA OFF)
66+
set(BUILD_HIP OFF)
5767
set(BUILD_MPS ON)
5868
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
5969
if(APPLE)
@@ -64,6 +74,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
6474
set(BUILD_XPU ON)
6575
else()
6676
set(BUILD_CUDA OFF)
77+
set(BUILD_HIP OFF)
6778
set(BUILD_MPS OFF)
6879
set(BUILD_XPU OFF)
6980
endif()
@@ -169,6 +180,33 @@ if(BUILD_CUDA)
169180

170181
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
171182
add_compile_definitions(BUILD_CUDA)
183+
elseif(BUILD_HIP)
    # HIP/ROCm backend: enable the HIP language, choose the GPU architectures,
    # add the HIP sources and append "_rocm<major><minor>" to the library name.
    enable_language(HIP)
    message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")

    # Architecture selection, in order of precedence:
    #   1. BNB_ROCM_ARCH, when defined
    #   2. an already non-empty CMAKE_HIP_ARCHITECTURES (left untouched)
    #   3. AMDGPU_TARGETS
    #   4. the default list below
    # NOTE(review): enable_language(HIP) above may already have initialised
    # CMAKE_HIP_ARCHITECTURES, in which case 3. and 4. never apply - confirm.
    if(DEFINED BNB_ROCM_ARCH)
        set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
    else()
        if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
            set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
        elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
            set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
        endif()
    endif()
    message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")

    list(APPEND SRC_FILES ${HIP_FILES})

    string(APPEND BNB_OUTPUT_NAME "_rocm")

    # Derive "<major>.<minor>" (HIP_VERSION) and "<major><minor>"
    # (HIP_VERSION_SHORT) from the output of `hipconfig --version`.
    execute_process(
        COMMAND hipconfig --version
        OUTPUT_VARIABLE HIP_CONFIG_VERSION
        RESULT_VARIABLE HIP_CONFIG_RESULT
    )
    string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
    string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")

    # Surface a failed or unparsable hipconfig call instead of silently
    # producing an unversioned library name and an empty HIP_VERSION.
    # HIP_CONFIG_RESULT is non-numeric when hipconfig could not be started,
    # which also makes the EQUAL test false.
    if(NOT HIP_CONFIG_RESULT EQUAL 0 OR "${HIP_VERSION}" STREQUAL "")
        message(WARNING
            "Could not determine the HIP version from `hipconfig --version` "
            "(result: '${HIP_CONFIG_RESULT}', output: '${HIP_CONFIG_VERSION}'); "
            "the library name will not carry a ROCm version suffix.")
    endif()

    string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
    add_compile_definitions(__HIP_PLATFORM_AMD__)
    add_compile_definitions(__HIP_PLATFORM_HCC__)
    add_compile_definitions(BUILD_HIP)
172210
elseif(BUILD_MPS)
173211
if(NOT APPLE)
174212
message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -223,6 +261,41 @@ if(BUILD_CUDA)
223261
CUDA_SEPARABLE_COMPILATION ON
224262
)
225263
endif()
264+
if(BUILD_HIP)
    # Link the bitsandbytes target against the ROCm math libraries.

    # ROCm root: the ROCM_PATH environment variable if set, otherwise /opt/rocm.
    # NOTE(review): a -DROCM_PATH=... passed on the command line is shadowed by
    # the set() calls below; confirm whether that is intended.
    if(NOT DEFINED ENV{ROCM_PATH})
        set(ROCM_PATH /opt/rocm)
    else()
        set(ROCM_PATH $ENV{ROCM_PATH})
    endif()
    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})

    # find_package() wrapper that also prints the version of the package found.
    # Extra arguments (e.g. REQUIRED) are forwarded unchanged. Kept as a macro so
    # that the variables set by find_package() stay visible in this scope.
    macro(find_package_and_print_version PACKAGE_NAME)
        find_package("${PACKAGE_NAME}" ${ARGN})
        message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
    endmacro()
    find_package_and_print_version(hipblas REQUIRED)
    find_package_and_print_version(hiprand REQUIRED)
    find_package_and_print_version(hipsparse REQUIRED)

    # Hacky way of excluding hip::amdhip64: with it linked, many tests fail
    # unexpectedly (e.g. adam8bit) because of inaccuracies.
    set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
    set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
    set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")

    # NOTE(review): the absolute /include and /lib entries look like leftovers
    # of an empty path prefix; confirm whether they are still needed.
    target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
    target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
    target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)

    target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
    set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

    # hipBLASLt is only used when HIP_VERSION is at least 6.1; older versions
    # build with NO_HIPBLASLT defined instead.
    if(HIP_VERSION VERSION_LESS "6.1")
        target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
    else()
        # REQUIRED: roc::hipblaslt is linked unconditionally below, so a missing
        # package must fail here with a clear message rather than at generate
        # time with a "target not found" error.
        find_package_and_print_version(hipblaslt REQUIRED)
        target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
    endif()
endif()
226299
if(BUILD_MPS)
227300
add_dependencies(bitsandbytes metallib)
228301
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")

README.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ bitsandbytes has the following minimum requirements for all platforms:
4848
</thead>
4949
<tbody>
5050
<tr>
51-
<td colspan="4">🐧 <strong>Linux, glibc >= 2.24</strong></td>
51+
<td colspan="6">🐧 <strong>Linux, glibc >= 2.24</strong></td>
5252
</tr>
5353
<tr>
5454
<td align="right">x86-64</td>
5555
<td>◻️ CPU</td>
5656
<td>AVX2</td>
57-
<td>〰️</td>
58-
<td>〰️</td>
57+
<td></td>
58+
<td></td>
5959
<td>❌</td>
6060
</tr>
6161
<tr>
@@ -93,16 +93,16 @@ bitsandbytes has the following minimum requirements for all platforms:
9393
<td></td>
9494
<td>🟪 Intel Gaudi <br><code>hpu</code></td>
9595
<td>Gaudi1, Gaudi2, Gaudi3</td>
96-
<td>🚧</td>
97-
<td>🚧</td>
96+
<td></td>
97+
<td>〰️</td>
9898
<td>❌</td>
9999
</tr>
100100
<tr>
101101
<td align="right">aarch64</td>
102102
<td>◻️ CPU</td>
103103
<td></td>
104-
<td>〰️</td>
105-
<td>〰️</td>
104+
<td></td>
105+
<td></td>
106106
<td>❌</td>
107107
</tr>
108108
<tr>
@@ -114,14 +114,14 @@ bitsandbytes has the following minimum requirements for all platforms:
114114
<td>✅</td>
115115
</tr>
116116
<tr>
117-
<td colspan="4">🪟 <strong>Windows 11 / Windows Server 2019+</strong></td>
117+
<td colspan="6">🪟 <strong>Windows 11 / Windows Server 2019+</strong></td>
118118
</tr>
119119
<tr>
120120
<td align="right">x86-64</td>
121121
<td>◻️ CPU</td>
122122
<td>AVX2</td>
123-
<td>〰️</td>
124-
<td>〰️</td>
123+
<td></td>
124+
<td></td>
125125
<td>❌</td>
126126
</tr>
127127
<tr>
@@ -144,7 +144,7 @@ bitsandbytes has the following minimum requirements for all platforms:
144144
<td>🚧</td>
145145
</tr>
146146
<tr>
147-
<td colspan="4">🍎 <strong>macOS 13.1+</strong></td>
147+
<td colspan="6">🍎 <strong>macOS 14+</strong></td>
148148
</tr>
149149
<tr>
150150
<td align="right">arm64</td>

bitsandbytes/backends/cuda/ops.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
99

1010
from ..._ops import register_kernel
11-
from ...cextension import lib
11+
from ...cextension import HIP_ENVIRONMENT, lib
1212

1313

1414
@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -210,7 +210,12 @@ def _get_col_absmax(
210210
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
211211
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
212212
torch._check_is_size(blocksize)
213-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
213+
214+
if HIP_ENVIRONMENT:
215+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
216+
else:
217+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
218+
214219
torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
215220

216221
n = A.numel()
@@ -264,7 +269,11 @@ def _(
264269
def _dequantize_blockwise_impl(
265270
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
266271
) -> None:
267-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
272+
if HIP_ENVIRONMENT:
273+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
274+
else:
275+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
276+
268277
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
269278
torch._check(
270279
dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -294,7 +303,11 @@ def _dequantize_blockwise_impl(
294303
def _(
295304
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
296305
) -> tuple[torch.Tensor, torch.Tensor]:
297-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
306+
if HIP_ENVIRONMENT:
307+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
308+
else:
309+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
310+
298311
torch._check(quant_type in ["fp4", "nf4"])
299312
torch._check(
300313
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -372,7 +385,11 @@ def _dequantize_4bit_impl(
372385
dtype: torch.dtype,
373386
out: torch.Tensor,
374387
) -> None:
375-
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
388+
if HIP_ENVIRONMENT:
389+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
390+
else:
391+
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
392+
376393
torch._check(quant_type in ["fp4", "nf4"])
377394
torch._check(
378395
dtype in [torch.bfloat16, torch.float16, torch.float32],

0 commit comments

Comments
 (0)