Skip to content

Commit 88d5f8f

Browse files
authored
CUDA/HIP: Fix kernel selection for the MMVQ MUL_MAT_ID kernel to align host selection with device launch bounds (ggml-org#21238)
The conditions `cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE` and `cc >= GGML_CUDA_CC_TURING` also match all non-NVIDIA devices. On HIP devices this caused us to attempt to launch the kernel with batch-size configurations larger than its launch bounds allow. This PR fixes the conditionals in `get_mmvq_mmid_max_batch` by guarding the NVIDIA checks with `GGML_CUDA_CC_IS_NVIDIA(cc)` and the AMD checks with `GGML_CUDA_CC_IS_AMD(cc)`. Fixes ggml-org#21191
1 parent d43375f commit 88d5f8f

1 file changed

Lines changed: 23 additions & 20 deletions

File tree

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
235235
// Host function: returns the max batch size for the current arch+type at runtime.
236236
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
237237
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
238-
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
239-
return MMVQ_MAX_BATCH_SIZE;
240-
}
241-
if (cc >= GGML_CUDA_CC_TURING) {
242-
return get_mmvq_mmid_max_batch_turing_plus(type);
243-
}
244238
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
239+
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
240+
return MMVQ_MAX_BATCH_SIZE;
241+
}
242+
if (cc >= GGML_CUDA_CC_TURING) {
243+
return get_mmvq_mmid_max_batch_turing_plus(type);
244+
}
245245
return get_mmvq_mmid_max_batch_pascal_older(type);
246246
}
247+
247248
// AMD
248-
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
249-
return get_mmvq_mmid_max_batch_rdna4(type);
250-
}
251-
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
252-
return get_mmvq_mmid_max_batch_rdna3(type);
253-
}
254-
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
255-
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
256-
}
257-
if (GGML_CUDA_CC_IS_CDNA(cc)) {
258-
return get_mmvq_mmid_max_batch_cdna(type);
259-
}
260-
if (GGML_CUDA_CC_IS_GCN(cc)) {
261-
return get_mmvq_mmid_max_batch_gcn(type);
249+
if (GGML_CUDA_CC_IS_AMD(cc)) {
250+
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
251+
return get_mmvq_mmid_max_batch_rdna4(type);
252+
}
253+
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
254+
return get_mmvq_mmid_max_batch_rdna3(type);
255+
}
256+
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
257+
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
258+
}
259+
if (GGML_CUDA_CC_IS_CDNA(cc)) {
260+
return get_mmvq_mmid_max_batch_cdna(type);
261+
}
262+
if (GGML_CUDA_CC_IS_GCN(cc)) {
263+
return get_mmvq_mmid_max_batch_gcn(type);
264+
}
262265
}
263266
return MMVQ_MAX_BATCH_SIZE;
264267
}

0 commit comments

Comments
 (0)