Skip to content

Commit f5e9cf3

Browse files
committed
Replace compile-time warp size with runtime query in host code
Add bnb_host_warp_size() that queries hipDeviceGetAttribute at runtime with per-device caching (up to 32 GPUs), replacing the compile-time BNB_WARP_SIZE macro in host-side dispatch. This fixes incorrect GEMV grid sizing on mixed CDNA/RDNA systems where warp size varies per device.
1 parent e63e29c commit f5e9cf3

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

csrc/common.cuh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@
77
// Warp size
88

99
#if BNB_HIP
10-
// CDNA (gfx9xx) = 64, RDNA = 32.
10+
// CDNA (gfx9xx) = 64, RDNA (gfx10xx/gfx11xx/gfx12xx) = 32.
11+
// __AMDGCN_WAVEFRONT_SIZE is not defined by all compiler versions (removed since ROCm 7.0),
12+
// so fall back to architecture-family macros when it is absent.
13+
// This is a macro that is defined by the compiler during each device-code pass and as such should only be used inside kernels.
1114
#ifdef __AMDGCN_WAVEFRONT_SIZE
1215
#define BNB_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE
16+
#elif defined(__GFX9__)
17+
#define BNB_WARP_SIZE 64 // CDNA
1318
#else
14-
#define BNB_WARP_SIZE 64 // Safe default for HIP (matches CDNA)
19+
#define BNB_WARP_SIZE 32 // RDNA and other
1520
#endif
1621
#else
1722
#define BNB_WARP_SIZE 32

csrc/ops.cu

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,23 @@
1010

1111
#define ERR_NOT_IMPLEMENTED 100
1212

13+
#if BNB_HIP
14+
#include <hip/hip_runtime.h>
15+
static int bnb_host_warp_size() {
16+
constexpr int MAX_DEVICES = 32;
17+
static int cache[MAX_DEVICES] = {};
18+
int dev;
19+
(void)hipGetDevice(&dev);
20+
if (dev < 0 || dev >= MAX_DEVICES) return 64;
21+
if (cache[dev] == 0)
22+
(void)hipDeviceGetAttribute(&cache[dev], hipDeviceAttributeWarpSize, dev);
23+
return cache[dev];
24+
}
25+
#else
26+
static constexpr int bnb_host_warp_size() { return 32; }
27+
#endif
28+
29+
1330
using std::cout;
1431
using std::endl;
1532

@@ -407,8 +424,7 @@ void gemm_4bit_inference_naive(
407424

408425
int num_blocks = (m + 3) / 4;
409426
#if BNB_HIP
410-
// On 64-wide warp architectures, each warp processes 2 rows instead of 4
411-
if (BNB_WARP_SIZE == 64) {
427+
if (bnb_host_warp_size() == 64) {
412428
num_blocks = (m + 1) / 2;
413429
}
414430
#endif

0 commit comments

Comments
 (0)