Skip to content

Commit 373f23b

Browse files
authored
[ROCm] Replace compile-time warp size with runtime query in host code (#1885)
* Replace compile-time warp size with runtime query in host code Add bnb_host_warp_size(), which calls hipDeviceGetAttribute at runtime with per-device caching (up to 32 GPUs), replacing the compile-time BNB_WARP_SIZE macro in host-side dispatch. This fixes the incorrect default of warp size 64 on RDNA, so that kernels are dispatched with the proper parameters. * Fix kernel dispatching for RDNA * Fix linting issues * Fix linting issues * Fix linting issues * Revert device array caching and instead only do device 0 * Use atomics to avoid a race condition * Fix linting issues
1 parent e63e29c commit 373f23b

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

csrc/common.cuh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
// Warp size
88

99
#if BNB_HIP
10-
// CDNA (gfx9xx) = 64, RDNA = 32.
10+
// CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32.
11+
// __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0),
12+
// so fall back to architecture-family macros when it is absent.
13+
// This is a macro that is defined by the compiler during each device-code pass and as such
14+
// should only be used inside kernels.
1115
#ifdef __AMDGCN_WAVEFRONT_SIZE
1216
#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
17+
#elif defined(__GFX9__)
18+
#define BNB_WARP_SIZE 64 // CDNA
1319
#else
14-
#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA)
20+
#define BNB_WARP_SIZE 32 // RDNA and other
1521
#endif
1622
#else
1723
#define BNB_WARP_SIZE 32

csrc/ops.cu

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,26 @@
1010

1111
#define ERR_NOT_IMPLEMENTED 100
1212

13+
#if BNB_HIP
#include <atomic>
#include <hip/hip_runtime.h>

// Returns the warp (wavefront) size used for host-side kernel dispatch.
//
// NOTE: This queries device 0 once and caches the result. On mixed RDNA+CDNA
// systems (warp size 32 vs 64) this will return the wrong value for whichever
// device doesn't match device 0.
//
// If the runtime query fails or reports a non-positive value, nothing is cached
// (so a later call retries) and 32 is returned. The call sites visible in this
// file only compare the result against 64, so the fallback selects the same
// non-64 path that an unchecked zero result would have selected.
//
// Relaxed ordering is used because the atomic carries only the integer itself;
// concurrent first callers may each run the query, and each stores its own
// result for device 0.
static int bnb_host_warp_size() {
    constexpr int kFallbackWarpSize = 32;
    static std::atomic<int> warp_size{0}; // 0 means "not queried yet"

    const int cached = warp_size.load(std::memory_order_relaxed);
    if (cached > 0) {
        return cached;
    }

    int queried = 0;
    const hipError_t status = hipDeviceGetAttribute(&queried, hipDeviceAttributeWarpSize, 0);
    if (status != hipSuccess || queried <= 0) {
        // Do not cache a failed query; otherwise one transient error would
        // pin a bogus warp size for the lifetime of the process.
        return kFallbackWarpSize;
    }

    warp_size.store(queried, std::memory_order_relaxed);
    return queried;
}
#else
// Non-HIP builds: the host-side warp size is the compile-time constant 32.
static constexpr int bnb_host_warp_size() { return 32; }
#endif
32+
1333
using std::cout;
1434
using std::endl;
1535

@@ -35,10 +55,16 @@ void quantizeBlockwise(
3555
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
3656
else if (blocksize == 64) {
3757
#if BNB_HIP
38-
// On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types
3958
if constexpr (DATA_TYPE > 0) {
40-
kQuantizeBlockwiseSmall<T, DATA_TYPE>
41-
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
59+
if (bnb_host_warp_size() == 64) {
60+
// CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64
61+
kQuantizeBlockwiseSmall<T, DATA_TYPE>
62+
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
63+
} else {
64+
// RDNA: standard kernel (same as CUDA path)
65+
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>
66+
<<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
67+
}
4268
} else {
4369
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
4470
}
@@ -407,8 +433,7 @@ void gemm_4bit_inference_naive(
407433

408434
int num_blocks = (m + 3) / 4;
409435
#if BNB_HIP
410-
// On 64-wide warp architectures, each warp processes 2 rows instead of 4
411-
if (BNB_WARP_SIZE == 64) {
436+
if (bnb_host_warp_size() == 64) {
412437
num_blocks = (m + 1) / 2;
413438
}
414439
#endif

0 commit comments

Comments
 (0)