Skip to content

Commit 0b33411

Browse files
TimDettmers and claude committed
Use conditional load/store algo for warp size compatibility
BLOCK_LOAD_WARP_TRANSPOSE requires threads >= warp_size. On CDNA (warp=64), kQuantizeBlockwise with BLOCK_SIZE=64 has only 32 threads. Fall back to BLOCK_LOAD_DIRECT / BLOCK_STORE_DIRECT when threads < BNB_WARP_SIZE. This avoids rocprim compilation errors while keeping WARP_TRANSPOSE for larger block sizes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c538ced commit 0b33411

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

csrc/kernels.cu

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -327,13 +327,22 @@ __global__ void kQuantizeBlockwise(
327327
float local_abs_max = 0.0f;
328328
int local_rand_idx = 0;
329329

330-
typedef bnb_cub::BlockLoad<T, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
330+
// WARP_TRANSPOSE requires block_dim >= warp_size. On CDNA (warp=64),
331+
// block_dim=32 (from BLOCK_SIZE=64/NUM_PER_TH=2) is too small. Fall back
332+
// to DIRECT load/store in that case.
333+
static constexpr int THREADS = BLOCK_SIZE / NUM_PER_TH;
334+
static constexpr auto LOAD_ALGO = (THREADS >= BNB_WARP_SIZE)
335+
? bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE : bnb_cub::BLOCK_LOAD_DIRECT;
336+
static constexpr auto STORE_ALGO = (THREADS >= BNB_WARP_SIZE)
337+
? bnb_cub::BLOCK_STORE_WARP_TRANSPOSE : bnb_cub::BLOCK_STORE_DIRECT;
338+
339+
typedef bnb_cub::BlockLoad<T, THREADS, NUM_PER_TH, LOAD_ALGO> LoadT;
331340
typedef bnb_cub::BlockStore<
332-
unsigned char, BLOCK_SIZE / NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH,
333-
bnb_cub::BLOCK_STORE_WARP_TRANSPOSE>
341+
unsigned char, THREADS, (DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH,
342+
STORE_ALGO>
334343
StoreChar;
335-
typedef bnb_cub::BlockReduce<float, BLOCK_SIZE / NUM_PER_TH> BlockReduce;
336-
typedef bnb_cub::BlockLoad<float, BLOCK_SIZE / NUM_PER_TH, NUM_PER_TH, bnb_cub::BLOCK_LOAD_WARP_TRANSPOSE>
344+
typedef bnb_cub::BlockReduce<float, THREADS> BlockReduce;
345+
typedef bnb_cub::BlockLoad<float, THREADS, NUM_PER_TH, LOAD_ALGO>
337346
LoadFloat;
338347

339348
__shared__ typename LoadT::TempStorage loadt;

0 commit comments

Comments
 (0)