Skip to content

Commit 83892a5

Browse files
committed
Fix kernel dispatching for RDNA
1 parent dee600c commit 83892a5

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

csrc/ops.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,16 @@ void quantizeBlockwise(
5252
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
5353
else if (blocksize == 64) {
5454
#if BNB_HIP
55-
// On HIP with 64-wide warps (CDNA), use specialized kernel for 4-bit types
5655
if constexpr (DATA_TYPE > 0) {
57-
kQuantizeBlockwiseSmall<T, DATA_TYPE>
58-
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
56+
if (bnb_host_warp_size() == 64) {
57+
// CDNA: kQuantizeBlockwiseSmall is compiled with THREADS=64
58+
kQuantizeBlockwiseSmall<T, DATA_TYPE>
59+
<<<(num_blocks + 1) / 2, 64>>>(code, A, absmax, out, rand, rand_offset, n);
60+
} else {
61+
// RDNA: standard kernel (same as CUDA path)
62+
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>
63+
<<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
64+
}
5965
} else {
6066
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
6167
}

0 commit comments

Comments
 (0)