Skip to content

Commit 131e3cb

Browse files
committed
Revert "cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (ggml-org#19645)"
This reverts commit ad8207a.
1 parent 81065fd commit 131e3cb

2 files changed

Lines changed: 33 additions & 11 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,12 +2286,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
22862286

22872287
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
22882288

2289-
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
22902289
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
22912290
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
22922291
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
22932292
if (ggml_is_quantized(src0->type)) {
2294-
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
2293+
if (ne2 <= 4) {
22952294
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
22962295
return;
22972296
}
@@ -2314,8 +2313,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
23142313
}
23152314
}
23162315

2317-
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
2318-
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
23192316
cudaStream_t stream = ctx.stream();
23202317

23212318
GGML_ASSERT(nb12 % nb11 == 0);
@@ -2880,6 +2877,15 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
28802877
bool use_cuda_graph = true;
28812878
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
28822879

2880+
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2881+
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2882+
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2883+
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2884+
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2885+
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2886+
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2887+
const std::string delta_net_prefix = "dnet_add";
2888+
28832889
for (int i = 0; i < cgraph->n_nodes; i++) {
28842890
ggml_tensor * node = cgraph->nodes[i];
28852891

@@ -2894,17 +2900,34 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
28942900
#endif
28952901
}
28962902

2897-
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
2898-
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
2899-
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
2900-
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
2901-
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
2902-
use_cuda_graph = false;
2903+
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2904+
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
29032905
#ifndef NDEBUG
29042906
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
29052907
#endif
29062908
}
29072909

2910+
if (node->op == GGML_OP_ADD &&
2911+
node->src[1] && node->src[1]->ne[1] > 1 &&
2912+
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2913+
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2914+
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2915+
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2916+
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2917+
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2918+
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
2919+
strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
2920+
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2921+
// by means of matching node names. See
2922+
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2923+
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2924+
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2925+
use_cuda_graph = false;
2926+
#ifndef NDEBUG
2927+
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2928+
#endif
2929+
}
2930+
29082931
if (!use_cuda_graph) {
29092932
break;
29102933
}

ggml/src/ggml-cuda/mmvq.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "common.cuh"
22

33
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
4-
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
54

65
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
76
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);

0 commit comments

Comments (0)