Skip to content

Commit 83216e0

Browse files
first commit
1 parent 9dd8b70 commit 83216e0

6 files changed

Lines changed: 1559 additions & 0 deletions

File tree

csrc/compat.cuh

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
// compat.cuh — Platform abstraction layer for CUDA/HIP portability
2+
//
3+
// This header resolves ALL mechanical differences between CUDA and HIP.
4+
// Kernel code should include this header and use the bnb_* types/macros
5+
// instead of cuda*/hip* identifiers directly.
6+
//
7+
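// The guard macro is BNB_HIP. It is always defined: 1 when compiling for ROCm/HIP and 0 otherwise,
8+
// so test it with `#if BNB_HIP`, never `#ifdef BNB_HIP`. HIP is detected from __HIPCC__ or from __HIP_PLATFORM_AMD__ (the latter set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)).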
9+
10+
#pragma once
11+
12+
// ============================================================================
13+
// Platform detection
14+
// ============================================================================
15+
16+
// BNB_HIP is always defined: 0 when neither HIP marker macro is present, 1 otherwise.
// Consumers must test it with `#if BNB_HIP` rather than `#ifdef`.
#if !defined(__HIP_PLATFORM_AMD__) && !defined(__HIPCC__)
#define BNB_HIP 0
#else
#define BNB_HIP 1
#endif
21+
22+
// ============================================================================
23+
// Runtime and FP16/BF16 headers
24+
// ============================================================================
25+
26+
#if BNB_HIP
27+
28+
#include <hip/hip_fp16.h>
29+
#include <hip/hip_math_constants.h>
30+
#include <hip/hip_runtime.h>
31+
32+
#else // CUDA
33+
34+
#include <cuda_bf16.h>
35+
#include <cuda_fp16.h>
36+
#include <cuda_runtime.h>
37+
#include <math_constants.h>
38+
#include <mma.h>
39+
40+
#endif
41+
42+
// ============================================================================
43+
// CUB / hipCUB — namespace alias
44+
//
45+
// Usage: bnb_cub::BlockLoad<...>, bnb_cub::BlockReduce<...>, etc.
46+
// This single alias eliminates ~90% of the cub:: vs hipcub:: differences.
47+
// ============================================================================
48+
49+
// Publish the platform's block/warp primitive library under the single name bnb_cub.
#if !BNB_HIP // CUDA: NVIDIA CUB

#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cub/warp/warp_reduce.cuh>
namespace bnb_cub = cub;

#else // HIP: hipCUB

#include <hipcub/hipcub.hpp>
namespace bnb_cub = hipcub;

#endif
66+
67+
// ============================================================================
68+
// Reduction operators — CUB's Max()/Sum() API differs across versions
69+
// ============================================================================
70+
71+
// BNB_MAX_OP / BNB_SUM_OP each expand to a freshly constructed binary functor object
// suitable for passing as the reduction operator of a bnb_cub reduce call.
#if BNB_HIP

#define BNB_MAX_OP hipcub::Max()
#define BNB_SUM_OP hipcub::Sum()

#else // CUDA

// CCCL deprecated the cub::Max()/cub::Sum() functors in favour of cuda::maximum<> and
// cuda::std::plus<>, and later removed them. Both operators are switched together so
// neither one keeps depending on the legacy functors; older toolkits keep the CUB spellings.
#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002
#include <cuda/functional>     // cuda::maximum
#include <cuda/std/functional> // cuda::std::plus
#define BNB_MAX_OP cuda::maximum<> {}
#define BNB_SUM_OP cuda::std::plus<> {}
#else
#define BNB_MAX_OP cub::Max()
#define BNB_SUM_OP cub::Sum()
#endif

#endif
89+
90+
// ============================================================================
91+
// Stream and error types
92+
// ============================================================================
93+
94+
#if BNB_HIP
95+
96+
using bnb_stream_t = hipStream_t;
97+
using bnb_error_t = hipError_t;
98+
99+
#define BNB_SUCCESS hipSuccess
100+
#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError()
101+
#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e)
102+
#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s)
103+
#define BNB_DEVICE_FREE(p) hipFree(p)
104+
105+
#else // CUDA
106+
107+
using bnb_stream_t = cudaStream_t;
108+
using bnb_error_t = cudaError_t;
109+
110+
#define BNB_SUCCESS cudaSuccess
111+
#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError()
112+
#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e)
113+
#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s)
114+
#define BNB_DEVICE_FREE(p) cudaFree(p)
115+
116+
#endif
117+
118+
// ============================================================================
119+
// Error checking macro (unified name, platform-specific implementation)
120+
// ============================================================================
121+
122+
#include <cstdio>  // std::fprintf
#include <cstdlib> // std::exit

// BNB_CHECK_RETURN(value): evaluate `value` exactly once as a bnb_error_t; if it is not
// BNB_SUCCESS, print the runtime's error string with the call site to stderr and
// terminate the process with exit code 1.
//
// The body is wrapped in do { ... } while (0) so the macro behaves as a single statement:
// `if (c) BNB_CHECK_RETURN(x); else ...` now parses, which the bare-brace form did not.
// Invocations must end with a semicolon, like any other statement.
#define BNB_CHECK_RETURN(value) \
    do { \
        bnb_error_t _bnb_stat = (value); \
        if (_bnb_stat != BNB_SUCCESS) { \
            std::fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \
            std::exit(1); \
        } \
    } while (0)

// Keep backward compat for existing code during migration
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)
133+
134+
// ============================================================================
135+
// BFloat16 type alias
136+
//
137+
// CUDA uses __nv_bfloat16, HIP uses hip_bfloat16. Unified as bnb_bfloat16.
138+
// ============================================================================
139+
140+
#if BNB_HIP
141+
using bnb_bfloat16 = hip_bfloat16;
142+
#else
143+
using bnb_bfloat16 = __nv_bfloat16;
144+
#endif
145+
146+
// ============================================================================
147+
// Data type enum aliases for BLAS/Sparse libraries
148+
// ============================================================================
149+
150+
// Element-type tags handed to the BLAS/Sparse libraries, one spelling per platform.
#if !BNB_HIP // CUDA

#define BNB_R_16F CUDA_R_16F
#define BNB_R_32F CUDA_R_32F
#define BNB_R_8I CUDA_R_8I
#define BNB_R_32I CUDA_R_32I

#else // HIP

#define BNB_R_16F HIP_R_16F
#define BNB_R_32F HIP_R_32F
#define BNB_R_8I HIP_R_8I
#define BNB_R_32I HIP_R_32I

#endif
165+
166+
// ============================================================================
167+
// BLAS Lt types and functions
168+
// ============================================================================
169+
170+
// Unified names for the "Lt" matmul API (hipBLASLt on HIP, cuBLASLt on CUDA) plus the
// shared BLAS status type. Both branches expose the same set of bnb_* names.
#if BNB_HIP

// The status type and the op/compute enums come from hipBLAS itself, so include it
// directly; they must stay available even when hipBLASLt is disabled.
#include <hipblas/hipblas.h>

using bnb_blas_status_t = hipblasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS

#define BNB_BLASLT_OP_T HIPBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I

// Everything below names hipblasLt* symbols. When NO_HIPBLASLT is defined the header is
// not included, so the aliases are skipped too; previously they were unconditional and the
// header failed to compile in that configuration. Code using these names must sit behind
// the same NO_HIPBLASLT guard.
#ifndef NO_HIPBLASLT
#include <hipblaslt/hipblaslt.h>

using bnb_blasLt_handle_t = hipblasLtHandle_t;
using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t;
using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t;
using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t;
using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t;

#define bnb_blasLtCreate hipblasLtCreate
#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy
#define bnb_blasLtMatmul hipblasLtMatmul
#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate
#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute
#define bnb_blasLtPrefDestroy hipblasLtMatmulPreferenceDestroy
#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic

#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
// NOTE(review): HIP maps to ..._BETA_HOST while CUDA maps to ..._BETA_ZERO; the two
// modes treat beta differently, so confirm callers pass a beta that suits both.
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST

#endif // NO_HIPBLASLT

#else // CUDA

#include <cublasLt.h>
#include <cublas_v2.h>

using bnb_blas_status_t = cublasStatus_t;
#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS

#define BNB_BLASLT_OP_T CUBLAS_OP_T
#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I

using bnb_blasLt_handle_t = cublasLtHandle_t;
using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t;
using bnb_blasLt_layout_t = cublasLtMatrixLayout_t;
// Preference/heuristic aliases added for parity with the HIP branch.
using bnb_blasLt_preference_t = cublasLtMatmulPreference_t;
using bnb_blasLt_heuristic_t = cublasLtMatmulHeuristicResult_t;

#define bnb_blasLtCreate cublasLtCreate
#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate
#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute
#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate
#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy
#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy
#define bnb_blasLtMatmul cublasLtMatmul
#define bnb_blasLtPrefCreate cublasLtMatmulPreferenceCreate
#define bnb_blasLtPrefSetAttr cublasLtMatmulPreferenceSetAttribute
#define bnb_blasLtPrefDestroy cublasLtMatmulPreferenceDestroy
#define bnb_blasLtAlgoGetHeuristic cublasLtMatmulAlgoGetHeuristic

#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA
#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE
#define BNB_BLASLT_PREF_MAX_WORKSPACE CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO

#endif
232+
233+
// ============================================================================
234+
// Sparse library types
235+
// ============================================================================
236+
237+
// Unified names for the sparse library (hipSPARSE on HIP, cuSPARSE on CUDA) and a
// CHECK_SPARSE(value) macro that evaluates `value` once and, when the status is not the
// platform's success code, prints the library's error string with the call site to
// stderr and exits with code 1.
//
// CHECK_SPARSE is wrapped in do { ... } while (0) so it acts as a single statement and is
// safe inside un-braced if/else; invocations must end with a semicolon.
#include <cstdio>  // std::fprintf
#include <cstdlib> // std::exit

#if BNB_HIP

#include <hipsparse/hipsparse.h>

using bnb_sparse_handle_t = hipsparseHandle_t;

#define bnb_sparseCreate hipsparseCreate
#define bnb_sparseCreateCoo hipsparseCreateCoo
#define bnb_sparseCreateDnMat hipsparseCreateDnMat
#define bnb_sparseSpMM_bufSize hipsparseSpMM_bufferSize
#define bnb_sparseSpMM hipsparseSpMM
#define bnb_sparseDestroySpMat hipsparseDestroySpMat
#define bnb_sparseDestroyDnMat hipsparseDestroyDnMat

#define BNB_SPARSE_INDEX_32I HIPSPARSE_INDEX_32I
#define BNB_SPARSE_INDEX_BASE_ZERO HIPSPARSE_INDEX_BASE_ZERO
#define BNB_SPARSE_ORDER_ROW HIPSPARSE_ORDER_ROW
#define BNB_SPARSE_OP_NON_TRANSPOSE HIPSPARSE_OPERATION_NON_TRANSPOSE
#define BNB_SPARSE_OP_TRANSPOSE HIPSPARSE_OPERATION_TRANSPOSE
#define BNB_SPARSE_SPMM_ALG_DEFAULT HIPSPARSE_SPMM_ALG_DEFAULT

#define CHECK_SPARSE(value) \
    do { \
        hipsparseStatus_t _stat = (value); \
        if (_stat != HIPSPARSE_STATUS_SUCCESS) { \
            std::fprintf(stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_stat), __LINE__, __FILE__); \
            std::exit(1); \
        } \
    } while (0)

#else // CUDA

#include <cusparse.h>

using bnb_sparse_handle_t = cusparseHandle_t;

#define bnb_sparseCreate cusparseCreate
#define bnb_sparseCreateCoo cusparseCreateCoo
#define bnb_sparseCreateDnMat cusparseCreateDnMat
#define bnb_sparseSpMM_bufSize cusparseSpMM_bufferSize
#define bnb_sparseSpMM cusparseSpMM
#define bnb_sparseDestroySpMat cusparseDestroySpMat
#define bnb_sparseDestroyDnMat cusparseDestroyDnMat

#define BNB_SPARSE_INDEX_32I CUSPARSE_INDEX_32I
#define BNB_SPARSE_INDEX_BASE_ZERO CUSPARSE_INDEX_BASE_ZERO
#define BNB_SPARSE_ORDER_ROW CUSPARSE_ORDER_ROW
#define BNB_SPARSE_OP_NON_TRANSPOSE CUSPARSE_OPERATION_NON_TRANSPOSE
#define BNB_SPARSE_OP_TRANSPOSE CUSPARSE_OPERATION_TRANSPOSE
#define BNB_SPARSE_SPMM_ALG_DEFAULT CUSPARSE_SPMM_ALG_DEFAULT

#define CHECK_SPARSE(value) \
    do { \
        cusparseStatus_t _stat = (value); \
        if (_stat != CUSPARSE_STATUS_SUCCESS) { \
            std::fprintf(stderr, "Error %s at line %d in file %s\n", cusparseGetErrorString(_stat), __LINE__, __FILE__); \
            std::exit(1); \
        } \
    } while (0)

#endif
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# CMakeLists.txt Changes for Unified Kernels
2+
3+
## Summary of changes
4+
5+
Replace separate `CUDA_FILES` and `HIP_FILES` with a single `GPU_FILES` list.
6+
For HIP builds, tell CMake to compile `.cu` files using the HIP language.
7+
8+
## Diff
9+
10+
```diff
11+
# Define included source files
12+
set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
13+
-set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
14+
-set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
15+
+set(GPU_FILES csrc/ops.cu csrc/kernels.cu)
16+
set(MPS_FILES csrc/mps_ops.mm)
17+
set(METAL_FILES csrc/mps_kernels.metal)
18+
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
19+
```
20+
21+
```diff
22+
if(BUILD_CUDA)
23+
# ... (CUDA setup unchanged)
24+
- list(APPEND SRC_FILES ${CUDA_FILES})
25+
+ list(APPEND SRC_FILES ${GPU_FILES})
26+
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
27+
add_compile_definitions(BUILD_CUDA)
28+
elseif(BUILD_HIP)
29+
# ... (HIP setup unchanged)
30+
- list(APPEND SRC_FILES ${HIP_FILES})
31+
+ list(APPEND SRC_FILES ${GPU_FILES})
32+
string(APPEND BNB_OUTPUT_NAME "_rocm")
33+
# ...
34+
```
35+
36+
```diff
37+
if(BUILD_HIP)
38+
# ...
39+
- set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
40+
+ set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP)
41+
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
42+
# ...
43+
endif()
44+
```
45+
46+
## Files to delete after migration
47+
48+
- `csrc/common_hip.cuh`
49+
- `csrc/kernels.hip`
50+
- `csrc/kernels_hip.cuh`
51+
- `csrc/ops.hip`
52+
- `csrc/ops_hip.cuh`

0 commit comments

Comments
 (0)