Skip to content

Commit d7f3e15

Browse files
TimDettmers and claude committed
Activate unified CUDA/HIP kernel files from csrc/examples/
Move the unified portability-header-based files from csrc/examples/ into csrc/, replacing the duplicated CUDA and HIP kernel files.

- Add compat.cuh and compat_device.cuh (portability headers)
- Replace common.cuh, kernels.cu, ops.cu, ops.cuh, pythonInterface.cpp, CMakeLists.txt with unified versions
- Update kernels.cuh: rename kQuantizeBlockwise32 -> kQuantizeBlockwiseSmall
- Delete HIP-only files: common_hip.cuh, kernels.hip, kernels_hip.cuh, ops.hip, ops_hip.cuh
- Delete csrc/examples/ (files are now in their final locations)

Net: 10 source files -> 7, ~3300 fewer lines of duplicated code. The same .cu files are compiled by both nvcc (CUDA) and hipcc (HIP).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b96843c commit d7f3e15

20 files changed

+485
-9024
lines changed

CMakeLists.txt

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,7 @@ endif()
2424

2525
# Define included source files
2626
set(CPP_FILES csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
27-
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
28-
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
27+
set(GPU_FILES csrc/ops.cu csrc/kernels.cu)
2928
set(MPS_FILES csrc/mps_ops.mm)
3029
set(METAL_FILES csrc/mps_kernels.metal)
3130
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
@@ -195,7 +194,7 @@ if(BUILD_CUDA)
195194
message(STATUS "CUDA Targets: ${CMAKE_CUDA_ARCHITECTURES}")
196195
message(STATUS "CUDA NVCC Flags: ${CMAKE_CUDA_FLAGS}")
197196

198-
list(APPEND SRC_FILES ${CUDA_FILES})
197+
list(APPEND SRC_FILES ${GPU_FILES})
199198

200199
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
201200
add_compile_definitions(BUILD_CUDA)
@@ -213,7 +212,7 @@ elseif(BUILD_HIP)
213212
endif()
214213
message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
215214

216-
list(APPEND SRC_FILES ${HIP_FILES})
215+
list(APPEND SRC_FILES ${GPU_FILES})
217216

218217
string(APPEND BNB_OUTPUT_NAME "_rocm")
219218

@@ -339,7 +338,7 @@ if(BUILD_HIP)
339338
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
340339

341340
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
342-
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
341+
set_source_files_properties(${GPU_FILES} PROPERTIES LANGUAGE HIP)
343342
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
344343

345344
if(HIP_VERSION VERSION_LESS "6.1")

csrc/common.cuh

Lines changed: 60 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,41 @@
1+
// common.cuh — Merged architecture constants for CUDA and HIP
2+
//
3+
// This replaces both the old csrc/common.cuh and csrc/common_hip.cuh.
4+
// Platform detection uses compat.cuh's BNB_HIP macro.
5+
16
#pragma once
27

3-
// TODO: Let's make some of these constexpr and put in a namespace.
8+
#include "compat.cuh"
9+
10+
// ============================================================================
11+
// Warp size
12+
// ============================================================================
13+
14+
#if BNB_HIP
15+
// AMD GFX9 (CDNA) uses 64-wide warps; RDNA uses 32-wide
16+
#ifdef __GFX9__
17+
#define BNB_WARP_SIZE 64
18+
#else
19+
#define BNB_WARP_SIZE 32
20+
#endif
21+
#else
22+
#define BNB_WARP_SIZE 32
23+
#endif
24+
25+
// ============================================================================
26+
// BF16 availability
27+
// ============================================================================
28+
29+
#if BNB_HIP
30+
// BF16 is available on all currently-supported ROCm architectures (CDNA2+, RDNA3+)
31+
#define BNB_BF16_AVAILABLE true
32+
#else
33+
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
34+
#endif
35+
36+
// ============================================================================
37+
// CUDA compute capability constants (CUDA-only, but harmless to define on HIP)
38+
// ============================================================================
439

540
#define BNB_CC_PASCAL 600
641
#define BNB_CC_PASCAL_X2 620
@@ -14,31 +49,45 @@
1449
#define BNB_CC_HOPPER 900
1550
#define BNB_CC_BLACKWELL 1000
1651

52+
// ============================================================================
53+
// Feature availability based on arch (CUDA uses __CUDA_ARCH__, HIP is simpler)
54+
// ============================================================================
55+
56+
#if BNB_HIP
57+
// HIP: MMA is not supported via mma.h, and FP8 support varies by arch, so all three features are disabled here
58+
#define BNB_FP16_MMA_AVAILABLE 0
59+
#define BNB_INT8_MMA_AVAILABLE 0
60+
#define BNB_FP8_AVAILABLE 0
61+
#else
1762
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
1863
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
19-
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
2064
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
65+
#endif
2166

22-
#define BNB_WARP_SIZE 32
67+
// ============================================================================
68+
// Maximum threads per SM/CU
69+
// ============================================================================
2370

24-
// The maximum number of resident threads per SM varies by arch.
25-
// For A100/H100 and all prior to Turing, it is 2048, which allows
26-
// for 2 full blocks of 1024 threads per SM.
27-
// Reference:
28-
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
71+
#if BNB_HIP
72+
// For currently supported ROCm architectures (CDNA2, RDNA3)
73+
#define BNB_MAX_THREADS_PER_SM 2048
74+
#else
75+
// The maximum number of resident threads per SM varies by NVIDIA arch.
76+
// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability
2977
#if __CUDA_ARCH__ == 750
3078
#define BNB_MAX_THREADS_PER_SM 1024
3179
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
3280
#define BNB_MAX_THREADS_PER_SM 1536
3381
#else
3482
#define BNB_MAX_THREADS_PER_SM 2048
3583
#endif
84+
#endif
3685

37-
// Maximum resident warps per SM is always directly related to the number of threads.
86+
// Maximum resident warps per SM/CU
3887
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
3988

40-
// Maximum resident blocks per SM may vary.
41-
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
89+
// Maximum resident blocks per SM/CU
90+
#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870)
4291
#define BNB_MAX_BLOCKS_PER_SM 16
4392
#else
4493
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)

csrc/common_hip.cuh

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)