
Commit 3934632

Authored by Abdennacer-Badaoui, TimDettmers, and claude
Unify CUDA and HIP kernel sources via compat.cuh portability layer (#1877)
* first commit
* update
* Activate unified CUDA/HIP kernel files from csrc/examples/

  Move the unified portability-header-based files from csrc/examples/ into csrc/, replacing the duplicated CUDA and HIP kernel files.

  - Add compat.cuh and compat_device.cuh (portability headers)
  - Replace common.cuh, kernels.cu, ops.cu, ops.cuh, pythonInterface.cpp, CMakeLists.txt with unified versions
  - Update kernels.cuh: rename kQuantizeBlockwise32 -> kQuantizeBlockwiseSmall
  - Delete HIP-only files: common_hip.cuh, kernels.hip, kernels_hip.cuh, ops.hip, ops_hip.cuh
  - Delete csrc/examples/ (files are now in their final locations)

  Net: 10 source files -> 7, ~3300 fewer lines of duplicated code. The same .cu files are compiled by both nvcc (CUDA) and hipcc (HIP).

* Fix HIP build errors in unified kernel files

  - Add a hip/hip_bfloat16.h include to compat.cuh (the bnb_bfloat16 type alias requires hip_bfloat16 to be defined)
  - Add a __syncwarp() no-op macro for HIP (AMD warps are always in lockstep)
  - Add a hipblas version check (#if hipblasVersionMajor >= 3) for GemmEx calls (ROCm 6.1 ships hipblas v2, which uses HIPBLAS_R_* rather than HIPBLAS_COMPUTE_*)
  - Fix an include in ops.cuh: common.h -> common.cuh (BNB_WARP_SIZE visibility)

* Restore the common.h include in ops.cuh for the DataType_t enum

  common.h defines the General8bit, FP4, and NF4 enum values used in template instantiations. It was previously the only include; now include both common.h (for DataType_t) and common.cuh (for BNB_WARP_SIZE).

* Guard blocksize=64 quantize instantiations for warp size compatibility

  On AMD CDNA GPUs (warp size 64), blocksize=64 would leave only one thread per warp in the quantize kernels, which is incompatible. Wrap these instantiations in #if BNB_WARP_SIZE == 32 so they compile only on NVIDIA.

* Guard all blocksize=64 quantize instantiations for warp size compatibility

  The previous commit missed the float/NF4 and all bnb_bfloat16 blocksize=64 instantiations. These use BLOCK_LOAD_WARP_TRANSPOSE with 32 threads (64/2), which requires block_dim >= warp_size. On CDNA (warp=64), 32 threads is insufficient.

* Use a conditional load/store algorithm for warp size compatibility

  BLOCK_LOAD_WARP_TRANSPOSE requires threads >= warp_size. On CDNA (warp=64), kQuantizeBlockwise with BLOCK_SIZE=64 has only 32 threads. Fall back to BLOCK_LOAD_DIRECT / BLOCK_STORE_DIRECT when threads < BNB_WARP_SIZE. This avoids rocprim compilation errors while keeping WARP_TRANSPOSE for larger block sizes.

* Fix BNB_WARP_SIZE detection for the HIP host compilation pass

  __GFX9__ is defined only during the device compilation pass, not during host compilation. This caused BNB_WARP_SIZE to be 32 on the host pass even for gfx942 (CDNA, warp=64), making the conditional WARP_TRANSPOSE vs DIRECT selection wrong. Use __AMDGCN_WAVEFRONT_SIZE instead, which the HIP compiler defines correctly in both host and device passes.

* Remove the blocksize=64 instantiation guards

  Now that kQuantizeBlockwise falls back to BLOCK_LOAD_DIRECT when threads < warp_size, the blocksize=64 instantiations compile correctly on both CUDA and HIP. The guards were causing linker errors because ops.cu still references these symbols for the General8bit dispatch path.

* Apply clang-format formatting fixes

* Merge unified-hip-validation and code cleanup

---------
Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
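The conditional load/store selection described in the commits above can be sketched in plain C++. The enumerator and function names below are illustrative stand-ins, not the actual cub/rocprim identifiers:

```cpp
// Sketch of the fallback rule: BLOCK_LOAD_WARP_TRANSPOSE needs at least one
// full warp of threads per block; otherwise fall back to a direct load.
enum class LoadAlgo { WarpTranspose, Direct };

constexpr LoadAlgo pick_load_algo(int threads, int warp_size) {
    // threads corresponds to BLOCK_SIZE / NUM_PER_TH in the real kernel
    return threads >= warp_size ? LoadAlgo::WarpTranspose : LoadAlgo::Direct;
}

// NVIDIA (warp 32): blocksize=64 gives 32 threads, WARP_TRANSPOSE is legal.
static_assert(pick_load_algo(32, 32) == LoadAlgo::WarpTranspose, "nv ok");
// CDNA (warp 64): the same 32 threads must fall back to a direct load.
static_assert(pick_load_algo(32, 64) == LoadAlgo::Direct, "cdna fallback");
```

Larger block sizes (threads >= 64) keep the transpose path on both platforms under this rule.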
1 parent 96b37ec commit 3934632

File tree

14 files changed: +644 −3023 lines

CMakeLists.txt

Lines changed: 4 additions & 5 deletions
@@ -54,8 +54,7 @@ endif()
 
 # Define included source files
 set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
-set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
-set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
+set(GPU_FILES csrc/ops.cu csrc/kernels.cu)
 set(MPS_FILES csrc/mps_ops.mm)
 set(METAL_FILES csrc/mps_kernels.metal)
 set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)

@@ -225,7 +224,7 @@ if(BUILD_CUDA)
     message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}")
     message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")
 
-    list(APPEND SRC_FILES ${CUDA_FILES})
+    list(APPEND SRC_FILES ${GPU_FILES})
 
     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
     add_compile_definitions(BUILD_CUDA)

@@ -244,7 +243,7 @@ elseif(BUILD_HIP)
     message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
     message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
 
-    list(APPEND SRC_FILES ${HIP_FILES})
+    list(APPEND SRC_FILES ${GPU_FILES})
 
     string(APPEND BNB_OUTPUT_NAME "_rocm")

@@ -389,7 +388,7 @@ if(BUILD_HIP)
     endif()
 
     target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
-    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
+    set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP)
     set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
 
     if(HIP_VERSION VERSION_LESS "6.1")

csrc/common.cuh

Lines changed: 47 additions & 11 deletions
@@ -1,6 +1,32 @@
+// common.cuh — Architecture constants and feature detection
+
 #pragma once
 
-// TODO: Let's make some of these constexpr and put in a namespace.
+#include "compat.cuh"
+
+// Warp size
+
+#if BNB_HIP
+// CDNA (gfx9xx) = 64, RDNA = 32.
+#ifdef __AMDGCN_WAVEFRONT_SIZE
+#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
+#else
+#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA)
+#endif
+#else
+#define BNB_WARP_SIZE 32
+#endif
+
+// BF16 availability
+
+#if BNB_HIP
+// BF16 is available on all currently-supported ROCm architectures (CDNA2+, RDNA3+)
+#define BNB_BF16_AVAILABLE true
+#else
+#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
+#endif
+
+// Compute capability constants
 
 #define BNB_CC_PASCAL 600
 #define BNB_CC_PASCAL_X2 620

@@ -14,31 +40,41 @@
 #define BNB_CC_HOPPER 900
 #define BNB_CC_BLACKWELL 1000
 
+// Feature availability based on arch
+
+#if BNB_HIP
+// HIP: MMA not supported via mma.h; FP8 support varies by arch
+#define BNB_FP16_MMA_AVAILABLE 0
+#define BNB_INT8_MMA_AVAILABLE 0
+#define BNB_FP8_AVAILABLE 0
+#else
 #define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
 #define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
-#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
 #define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
+#endif
 
-#define BNB_WARP_SIZE 32
+// Maximum threads per SM/CU
 
-// The maximum number of resident threads per SM varies by arch.
-// For A100/H100 and all prior to Turing, it is 2048, which allows
-// for 2 full blocks of 1024 threads per SM.
-// Reference:
-// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+#if BNB_HIP
+// For currently supported ROCm architectures (CDNA2, RDNA3)
+#define BNB_MAX_THREADS_PER_SM 2048
+#else
+// The maximum number of resident threads per SM varies by NVIDIA arch.
+// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability
 #if __CUDA_ARCH__ == 750
 #define BNB_MAX_THREADS_PER_SM 1024
 #elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
 #define BNB_MAX_THREADS_PER_SM 1536
 #else
 #define BNB_MAX_THREADS_PER_SM 2048
 #endif
+#endif
 
-// Maximum resident warps per SM is always directly related to the number of threads.
+// Maximum resident warps per SM/CU
 #define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
 
-// Maximum resident blocks per SM may vary.
-#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
+// Maximum resident blocks per SM/CU
+#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870)
 #define BNB_MAX_BLOCKS_PER_SM 16
 #else
 #define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
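The derived occupancy constants above are plain integer arithmetic, so warp size propagates directly into the warp and block limits. A host-side sketch (hypothetical constexpr mirrors of the macros, for illustration only):

```cpp
// Hypothetical constexpr equivalents of BNB_MAX_WARPS_PER_SM and the
// default BNB_MAX_BLOCKS_PER_SM formula (half the warp limit).
constexpr int max_warps_per_sm(int max_threads, int warp_size) {
    return max_threads / warp_size;
}

constexpr int max_blocks_per_sm(int max_threads, int warp_size) {
    return max_warps_per_sm(max_threads, warp_size) / 2;
}

// NVIDIA default: 2048 threads / warp 32 -> 64 warps, 32 blocks.
static_assert(max_warps_per_sm(2048, 32) == 64, "");
static_assert(max_blocks_per_sm(2048, 32) == 32, "");
// CDNA: 2048 threads / wavefront 64 -> 32 warps, 16 blocks.
static_assert(max_warps_per_sm(2048, 64) == 32, "");
static_assert(max_blocks_per_sm(2048, 64) == 16, "");
```

This is why a single BNB_WARP_SIZE definition per platform is enough: every downstream occupancy constant adjusts automatically.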

csrc/common_hip.cuh

Lines changed: 0 additions & 11 deletions
This file was deleted.

csrc/compat.cuh

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// compat.cuh — Platform abstraction layer for CUDA/HIP portability
2+
//
3+
// This header resolves ALL mechanical differences between CUDA and HIP.
4+
// Kernel code should include this header and use the bnb_* types/macros
5+
// instead of cuda*/hip* identifiers directly.
6+
//
7+
// The guard macro is BNB_HIP, which is defined when compiling for ROCm/HIP
8+
// (set via CMakeLists.txt's add_compile_definitions(__HIP_PLATFORM_AMD__)).
9+
10+
#pragma once
11+
12+
// Platform detection
13+
14+
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
15+
#define BNB_HIP 1
16+
#else
17+
#define BNB_HIP 0
18+
#endif
19+
20+
// Runtime and FP16/BF16 headers
21+
22+
#if BNB_HIP
23+
24+
#include <hip/hip_bfloat16.h>
25+
#include <hip/hip_fp16.h>
26+
#include <hip/hip_math_constants.h>
27+
#include <hip/hip_runtime.h>
28+
#include <hipblas/hipblas.h>
29+
#include <rocblas/rocblas.h>
30+
31+
#else // CUDA
32+
33+
#include <cuda_bf16.h>
34+
#include <cuda_fp16.h>
35+
#include <cuda_runtime.h>
36+
37+
#endif
38+
39+
// Stream and error types
40+
41+
#if BNB_HIP
42+
43+
using bnb_stream_t = hipStream_t;
44+
using bnb_error_t = hipError_t;
45+
46+
#define BNB_SUCCESS hipSuccess
47+
#define BNB_PEEK_LAST_ERROR() hipPeekAtLastError()
48+
#define BNB_GET_ERROR_STRING(e) hipGetErrorString(e)
49+
#define BNB_DEVICE_MALLOC(p, s) hipMalloc(p, s)
50+
#define BNB_DEVICE_FREE(p) hipFree(p)
51+
#define BNB_DEVICE_MEMSET(p, v, s) hipMemset(p, v, s)
52+
53+
#else // CUDA
54+
55+
using bnb_stream_t = cudaStream_t;
56+
using bnb_error_t = cudaError_t;
57+
58+
#define BNB_SUCCESS cudaSuccess
59+
#define BNB_PEEK_LAST_ERROR() cudaPeekAtLastError()
60+
#define BNB_GET_ERROR_STRING(e) cudaGetErrorString(e)
61+
#define BNB_DEVICE_MALLOC(p, s) cudaMalloc(p, s)
62+
#define BNB_DEVICE_FREE(p) cudaFree(p)
63+
#define BNB_DEVICE_MEMSET(p, v, s) cudaMemset(p, v, s)
64+
65+
#endif
66+
67+
// Error checking
68+
69+
#define BNB_CHECK_RETURN(value) \
70+
{ \
71+
bnb_error_t _bnb_stat = value; \
72+
if (_bnb_stat != BNB_SUCCESS) { \
73+
fprintf(stderr, "Error %s at line %d in file %s\n", BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \
74+
exit(1); \
75+
} \
76+
}
77+
78+
// Keep backward compat for existing code during migration
79+
#define CUDA_CHECK_RETURN(value) BNB_CHECK_RETURN(value)
80+
81+
// Warp synchronization
82+
//
83+
// HIP warps are always in lockstep (no independent thread scheduling),
84+
// so __syncwarp() is a no-op. CUDA needs it for warp convergence.
85+
86+
#if BNB_HIP
87+
#define __syncwarp() \
88+
do { \
89+
} while (0)
90+
#endif
91+
92+
// BFloat16 type alias
93+
94+
#if BNB_HIP
95+
using bnb_bfloat16 = hip_bfloat16;
96+
#else
97+
using bnb_bfloat16 = __nv_bfloat16;
98+
#endif
99+
100+
// Data type enum aliases for BLAS libraries
101+
102+
#if BNB_HIP
103+
104+
#define BNB_R_16F HIP_R_16F
105+
#define BNB_R_32F HIP_R_32F
106+
#define BNB_R_8I HIP_R_8I
107+
#define BNB_R_32I HIP_R_32I
108+
109+
#else // CUDA
110+
111+
#define BNB_R_16F CUDA_R_16F
112+
#define BNB_R_32F CUDA_R_32F
113+
#define BNB_R_8I CUDA_R_8I
114+
#define BNB_R_32I CUDA_R_32I
115+
116+
#endif
117+
118+
// BLAS Lt types and functions
119+
120+
#if BNB_HIP
121+
122+
#ifndef NO_HIPBLASLT
123+
#include <hipblaslt/hipblaslt.h>
124+
#endif
125+
126+
using bnb_blasLt_handle_t = hipblasLtHandle_t;
127+
using bnb_blasLt_matmul_desc_t = hipblasLtMatmulDesc_t;
128+
using bnb_blasLt_layout_t = hipblasLtMatrixLayout_t;
129+
using bnb_blasLt_preference_t = hipblasLtMatmulPreference_t;
130+
131+
#define BNB_BLASLT_OP_T HIPBLAS_OP_T
132+
#define BNB_BLASLT_COMPUTE_32I HIPBLAS_COMPUTE_32I
133+
134+
#define bnb_blasLtCreate hipblasLtCreate
135+
#define bnb_blasLtMatmulDescCreate hipblasLtMatmulDescCreate
136+
#define bnb_blasLtMatmulDescSetAttr hipblasLtMatmulDescSetAttribute
137+
#define bnb_blasLtLayoutCreate hipblasLtMatrixLayoutCreate
138+
#define bnb_blasLtLayoutDestroy hipblasLtMatrixLayoutDestroy
139+
#define bnb_blasLtMatmulDescDestroy hipblasLtMatmulDescDestroy
140+
#define bnb_blasLtMatmul hipblasLtMatmul
141+
#define bnb_blasLtPrefCreate hipblasLtMatmulPreferenceCreate
142+
#define bnb_blasLtPrefSetAttr hipblasLtMatmulPreferenceSetAttribute
143+
#define bnb_blasLtAlgoGetHeuristic hipblasLtMatmulAlgoGetHeuristic
144+
145+
#define BNB_BLASLT_DESC_TRANSA HIPBLASLT_MATMUL_DESC_TRANSA
146+
#define BNB_BLASLT_DESC_POINTER_MODE HIPBLASLT_MATMUL_DESC_POINTER_MODE
147+
#define BNB_BLASLT_PREF_MAX_WORKSPACE HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
148+
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST
149+
150+
using bnb_blasLt_heuristic_t = hipblasLtMatmulHeuristicResult_t;
151+
using bnb_blas_status_t = hipblasStatus_t;
152+
#define BNB_BLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
153+
154+
#else // CUDA
155+
156+
#include <cublasLt.h>
157+
#include <cublas_v2.h>
158+
159+
using bnb_blasLt_handle_t = cublasLtHandle_t;
160+
using bnb_blasLt_matmul_desc_t = cublasLtMatmulDesc_t;
161+
using bnb_blasLt_layout_t = cublasLtMatrixLayout_t;
162+
163+
#define BNB_BLASLT_OP_T CUBLAS_OP_T
164+
#define BNB_BLASLT_COMPUTE_32I CUBLAS_COMPUTE_32I
165+
166+
#define bnb_blasLtCreate cublasLtCreate
167+
#define bnb_blasLtMatmulDescCreate cublasLtMatmulDescCreate
168+
#define bnb_blasLtMatmulDescSetAttr cublasLtMatmulDescSetAttribute
169+
#define bnb_blasLtLayoutCreate cublasLtMatrixLayoutCreate
170+
#define bnb_blasLtLayoutDestroy cublasLtMatrixLayoutDestroy
171+
#define bnb_blasLtMatmulDescDestroy cublasLtMatmulDescDestroy
172+
#define bnb_blasLtMatmul cublasLtMatmul
173+
174+
#define BNB_BLASLT_DESC_TRANSA CUBLASLT_MATMUL_DESC_TRANSA
175+
#define BNB_BLASLT_DESC_POINTER_MODE CUBLASLT_MATMUL_DESC_POINTER_MODE
176+
#define BNB_BLASLT_PTR_MODE_ALPHA_VEC CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO
177+
178+
using bnb_blas_status_t = cublasStatus_t;
179+
#define BNB_BLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS
180+
181+
#endif
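The BNB_CHECK_RETURN macro above can be exercised off-device with stand-in types. This sketch uses hypothetical mock definitions (bnb_error_t as int, a fake error-string function) purely to show the macro's shape; it is not the real CUDA/HIP runtime:

```cpp
#include <cstdio>
#include <cstdlib>

// Mock stand-ins for the platform-neutral aliases (illustrative only).
using bnb_error_t = int;
constexpr bnb_error_t BNB_SUCCESS = 0;
inline const char* BNB_GET_ERROR_STRING(bnb_error_t) { return "mock error"; }

// Same shape as BNB_CHECK_RETURN in compat.cuh: evaluate the expression
// once, report file/line on failure, and abort.
#define BNB_CHECK_RETURN(value) \
    { \
        bnb_error_t _bnb_stat = (value); \
        if (_bnb_stat != BNB_SUCCESS) { \
            fprintf(stderr, "Error %s at line %d in file %s\n", \
                    BNB_GET_ERROR_STRING(_bnb_stat), __LINE__, __FILE__); \
            exit(1); \
        } \
    }

// A call that succeeds passes through the macro with no output.
inline bnb_error_t mock_device_call_ok() { return BNB_SUCCESS; }
```

Because the macro only touches bnb_* names, the same call sites compile unchanged whether those names resolve to cuda* or hip* underneath.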

csrc/compat_device.cuh

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+// compat_device.cuh — Device-only portability layer (CUB, reduction ops, MMA)
+//
+// Include this from .cu kernel files only (compiled by nvcc/hipcc).
+// Do NOT include from .cpp files — use compat.cuh instead for host-safe types.
+
+#pragma once
+
+#include "compat.cuh"
+
+// CUB / hipCUB — namespace alias
+
+#if BNB_HIP
+
+#include <hipcub/hipcub.hpp>
+namespace bnb_cub = hipcub;
+
+#else // CUDA
+
+#include <cub/block/block_discontinuity.cuh>
+#include <cub/block/block_load.cuh>
+#include <cub/block/block_radix_sort.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <cub/block/block_store.cuh>
+#include <cub/cub.cuh>
+#include <cub/warp/warp_reduce.cuh>
+#include <math_constants.h>
+#include <mma.h>
+namespace bnb_cub = cub;
+
+#endif
+
+// Reduction operators
+
+#if BNB_HIP
+
+#define BNB_MAX_OP hipcub::Max()
+#define BNB_SUM_OP hipcub::Sum()
+
+#else // CUDA
+
+// CCCL 2.8.2+ moved to cuda::maximum<>{}, older versions use cub::Max()
+#if defined(CCCL_VERSION) && CCCL_VERSION >= 2008002
+#include <cuda/std/functional>
+#define BNB_MAX_OP \
+    cuda::maximum<> {}
+#else
+#define BNB_MAX_OP cub::Max()
+#endif
+#define BNB_SUM_OP cub::Sum()
+
+#endif
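The BNB_MAX_OP / BNB_SUM_OP indirection works because all three backends (cub, hipcub, cuda::maximum) expose a binary functor that a reduction can apply pairwise. A host-side stand-in with the same call shape (hypothetical Max/Sum structs and a serial reduce, not the real block reduction) illustrates why swapping the operator alias is enough:

```cpp
// Functors with the same call shape as cub::Max() / cub::Sum(), so a
// reduction can take the operator as an argument and stay platform-neutral.
struct Max {
    template <typename T>
    T operator()(const T& a, const T& b) const { return a > b ? a : b; }
};
struct Sum {
    template <typename T>
    T operator()(const T& a, const T& b) const { return a + b; }
};

// Serial stand-in for a block reduction: fold the operator over the data.
template <typename T, typename Op>
T reduce(const T* data, int n, T init, Op op) {
    T acc = init;
    for (int i = 0; i < n; ++i)
        acc = op(acc, data[i]);
    return acc;
}
```

A kernel written against BNB_MAX_OP never names the backend; only the macro expansion changes between CUDA, HIP, and newer CCCL versions.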
