|
| 1 | +// common.cuh — Merged architecture constants for CUDA and HIP |
| 2 | +// |
| 3 | +// This replaces both the old csrc/common.cuh and csrc/common_hip.cuh. |
| 4 | +// Platform detection uses compat.cuh's BNB_HIP macro. |
| 5 | + |
1 | 6 | #pragma once |
2 | 7 |
|
3 | | -// TODO: Let's make some of these constexpr and put in a namespace. |
| 8 | +#include "compat.cuh" |
| 9 | + |
| 10 | +// ============================================================================ |
| 11 | +// Warp size |
| 12 | +// ============================================================================ |
| 13 | + |
| 14 | +#if BNB_HIP |
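| 15 | +// AMD GFX9-family targets (GCN5/Vega and CDNA) use 64-wide wavefronts; RDNA defaults to 32-wide. NOTE(review): __GFX9__ is a device-target macro — confirm host-side uses of BNB_WARP_SIZE do not silently get 32. |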
| 16 | +#ifdef __GFX9__ |
| 17 | +#define BNB_WARP_SIZE 64 |
| 18 | +#else |
| 19 | +#define BNB_WARP_SIZE 32 |
| 20 | +#endif |
| 21 | +#else |
| 22 | +#define BNB_WARP_SIZE 32 |
| 23 | +#endif |
| 24 | + |
| 25 | +// ============================================================================ |
| 26 | +// BF16 availability |
| 27 | +// ============================================================================ |
| 28 | + |
| 29 | +#if BNB_HIP |
| 30 | +// BF16 is available on all currently supported ROCm architectures (CDNA2+, RDNA3+) |
| 31 | +#define BNB_BF16_AVAILABLE true |
| 32 | +#else |
| 33 | +#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) |
| 34 | +#endif |
| 35 | + |
| 36 | +// ============================================================================ |
| 37 | +// CUDA compute capability constants (CUDA-only, but harmless to define on HIP) |
| 38 | +// ============================================================================ |
4 | 39 |
|
5 | 40 | #define BNB_CC_PASCAL 600 |
6 | 41 | #define BNB_CC_PASCAL_X2 620 |
|
14 | 49 | #define BNB_CC_HOPPER 900 |
15 | 50 | #define BNB_CC_BLACKWELL 1000 |
16 | 51 |
|
| 52 | +// ============================================================================ |
| 53 | +// Feature availability by architecture (CUDA keys off __CUDA_ARCH__; HIP uses fixed values) |
| 54 | +// ============================================================================ |
| 55 | + |
| 56 | +#if BNB_HIP |
| 57 | +// HIP: MMA is not supported via mma.h; FP8 support varies by arch, so it is disabled unconditionally here |
| 58 | +#define BNB_FP16_MMA_AVAILABLE 0 |
| 59 | +#define BNB_INT8_MMA_AVAILABLE 0 |
| 60 | +#define BNB_FP8_AVAILABLE 0 |
| 61 | +#else |
17 | 62 | #define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA) |
18 | 63 | #define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER) |
19 | | -#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE) |
20 | 64 | #define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA) |
| 65 | +#endif |
21 | 66 |
|
22 | | -#define BNB_WARP_SIZE 32 |
| 67 | +// ============================================================================ |
| 68 | +// Maximum threads per SM/CU |
| 69 | +// ============================================================================ |
23 | 70 |
|
24 | | -// The maximum number of resident threads per SM varies by arch. |
25 | | -// For A100/H100 and all prior to Turing, it is 2048, which allows |
26 | | -// for 2 full blocks of 1024 threads per SM. |
27 | | -// Reference: |
28 | | -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability |
| 71 | +#if BNB_HIP |
| 72 | +// For currently supported ROCm architectures (CDNA2, RDNA3) |
| 73 | +#define BNB_MAX_THREADS_PER_SM 2048 |
| 74 | +#else |
| 75 | +// The maximum number of resident threads per SM varies by NVIDIA arch. |
| 76 | +// Reference: CUDA Programming Guide, Technical Specifications per Compute Capability |
29 | 77 | #if __CUDA_ARCH__ == 750 |
30 | 78 | #define BNB_MAX_THREADS_PER_SM 1024 |
31 | 79 | #elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890 |
32 | 80 | #define BNB_MAX_THREADS_PER_SM 1536 |
33 | 81 | #else |
34 | 82 | #define BNB_MAX_THREADS_PER_SM 2048 |
35 | 83 | #endif |
| 84 | +#endif |
36 | 85 |
|
37 | | -// Maximum resident warps per SM is always directly related to the number of threads. |
| 86 | +// Maximum resident warps per SM/CU |
38 | 87 | #define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE)) |
39 | 88 |
|
40 | | -// Maximum resident blocks per SM may vary. |
41 | | -#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 |
| 89 | +// Maximum resident blocks per SM/CU |
| 90 | +#if !BNB_HIP && (defined(__CUDA_ARCH__)) && (__CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870) |
42 | 91 | #define BNB_MAX_BLOCKS_PER_SM 16 |
43 | 92 | #else |
44 | 93 | #define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2) |
|
0 commit comments