@@ -3308,6 +3308,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
33083308 return true ;
33093309}
33103310
3311+ // returns whether the write (out) nodes overwrite the read nodes in operation
3312+ static bool ggml_cuda_check_fusion_memory_ranges (const ggml_cgraph * cgraph,
3313+ const int node_idx,
3314+ const int node_count,
3315+ const int * out_nodes,
3316+ const int out_count,
3317+ const bool is_topk_moe = false ) {
3318+ auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3319+ const int64_t a_start = (int64_t ) a->data ;
3320+ const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size (a->buffer ->buft , a);
3321+
3322+ const int64_t b_start = (int64_t ) b->data ;
3323+ const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size (b->buffer ->buft , b);
3324+
3325+ if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3326+ return true ;
3327+ }
3328+
3329+ return false ;
3330+ };
3331+
3332+ bool is_ok = true ;
3333+ // exception for topk-moe, as each row is read entirely before writing
3334+ if (ggml_nrows (cgraph->nodes [node_idx]) == 1 && is_topk_moe) {
3335+ return true ;
3336+ }
3337+
3338+ for (int i = 0 ; i < out_count; ++i) {
3339+ const ggml_tensor * dst = cgraph->nodes [out_nodes[i]];
3340+
3341+ for (int j = node_idx; j < node_idx + node_count; ++j) {
3342+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
3343+ // the destination and the src is not an intermediate node that's being
3344+ // elided, then disable fusion.
3345+
3346+ for (int src_idx = 0 ; src_idx < GGML_MAX_SRC; ++src_idx) {
3347+ const ggml_tensor * src = cgraph->nodes [j]->src [src_idx];
3348+
3349+ if (!src || src->op == GGML_OP_NONE) {
3350+ continue ;
3351+ }
3352+
3353+ if (nodes_overlap (dst, src)) {
3354+ bool found = false ;
3355+
3356+ for (int k = node_idx; k < j; ++k) {
3357+ if (cgraph->nodes [k] == src) {
3358+ found = true ;
3359+ break ;
3360+ }
3361+ }
3362+
3363+ if (!found) {
3364+ is_ok = false ;
3365+ break ;
3366+ }
3367+ }
3368+ }
3369+ }
3370+ }
3371+
3372+ return is_ok;
3373+ }
3374+
3375+
33113376static bool ggml_cuda_can_fuse (const struct ggml_cgraph * cgraph,
33123377 int node_idx,
33133378 std::initializer_list<enum ggml_op> ops,
@@ -3337,7 +3402,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
33373402 const ggml_tensor * glu = cgraph->nodes [node_idx + 4 ];
33383403
33393404 if (ggml_cuda_should_fuse_mul_mat (ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3340- return true ;
3405+ int out_nodes[] = { node_idx + 4 };
3406+ return ggml_cuda_check_fusion_memory_ranges (cgraph, node_idx, (int )ops.size (), out_nodes, 1 );
33413407 }
33423408 }
33433409
@@ -3348,7 +3414,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
33483414 const ggml_tensor * glu = cgraph->nodes [node_idx + 2 ];
33493415
33503416 if (ggml_cuda_should_fuse_mul_mat (ffn_up, ffn_gate, glu)) {
3351- return true ;
3417+ int out_nodes[] = { node_idx + 2 };
3418+ return ggml_cuda_check_fusion_memory_ranges (cgraph, node_idx, (int )ops.size (), out_nodes, 1 );
33523419 }
33533420 }
33543421
@@ -3474,69 +3541,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
34743541 return false ;
34753542}
34763543
// Pre-change version of the fusion memory-range check (the lines this diff removes).
// For every output node in out_nodes it walks all srcs of the nodes
// [node_idx, node_idx + node_count) and reports false when an output overlaps a
// src that is not produced by an earlier node of the same fusion.
// Differences from the added version earlier in this diff: tensor extents are
// measured with ggml_nbytes, the single-row shortcut applies to every caller,
// and the graph/out_nodes parameters are not const-qualified.
3477- // returns whether the write (out) nodes overwrite the read nodes in operation
3478- static bool ggml_cuda_check_fusion_memory_ranges (ggml_cgraph * cgraph,
3479- int node_idx,
3480- int node_count,
3481- int * out_nodes,
3482- int out_count) {
// nodes_overlap: true when the address ranges [data, data + ggml_nbytes) of a and b intersect
3483- auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
3484- const int64_t a_start = (int64_t ) a->data ;
3485- const int64_t a_end = a_start + ggml_nbytes (a);
3486-
3487- const int64_t b_start = (int64_t ) b->data ;
3488- const int64_t b_end = b_start + ggml_nbytes (b);
3489-
// either range starts inside the other one
3490- if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
3491- return true ;
3492- }
3493-
3494- return false ;
3495- };
3496-
3497- bool is_ok = true ;
3498- // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
3499- if (ggml_nrows (cgraph->nodes [node_idx]) == 1 ) {
3500- return true ;
3501- }
3502-
3503- for (int i = 0 ; i < out_count; ++i) {
3504- const ggml_tensor * dst = cgraph->nodes [out_nodes[i]];
3505-
3506- for (int j = node_idx; j < node_idx + node_count; ++j) {
3507- // Loop over all srcs of all nodes in the fusion. If the src overlaps
3508- // the destination and the src is not an intermediate node that's being
3509- // elided, then disable fusion.
3510-
3511- for (int src_idx = 0 ; src_idx < GGML_MAX_SRC; ++src_idx) {
3512- const ggml_tensor * src = cgraph->nodes [j]->src [src_idx];
3513-
// skip empty src slots and srcs whose op is GGML_OP_NONE
3514- if (!src || src->op == GGML_OP_NONE) {
3515- continue ;
3516- }
3517-
3518- if (nodes_overlap (dst, src)) {
3519- bool found = false ;
3520-
// an overlapping src is acceptable only if an earlier node of this fusion produces it
3521- for (int k = node_idx; k < j; ++k) {
3522- if (cgraph->nodes [k] == src) {
3523- found = true ;
3524- break ;
3525- }
3526- }
3527-
// note: this break leaves only the src loop; the outer loops keep running, but is_ok stays false
3528- if (!found) {
3529- is_ok = false ;
3530- break ;
3531- }
3532- }
3533- }
3534- }
3535- }
3536-
3537- return is_ok;
3538- }
3539-
3539-
35403544static void ggml_cuda_graph_evaluate_and_capture (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
35413545 bool graph_evaluated_or_captured = false ;
35423546
@@ -3734,7 +3738,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37343738
37353739 if (ggml_can_fuse_subgraph (cgraph, i, ops.size (), ops.data (), out_nodes, 2 ) &&
37363740 ggml_cuda_should_use_topk_moe (node, logits, weights, ids) &&
3737- ggml_cuda_check_fusion_memory_ranges (cgraph, i, ops.size (), out_nodes, 2 )) {
3741+ ggml_cuda_check_fusion_memory_ranges (cgraph, i, ops.size (), out_nodes, 2 , /* is_topk_moe= */ true )) {
37383742 ggml_cuda_op_topk_moe (*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
37393743 i += ops.size () - 1 ;
37403744 continue ;
@@ -3750,7 +3754,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37503754 int out_nodes[2 ] = { i + 1 , i + 5 };
37513755 if (ggml_can_fuse_subgraph (cgraph, i, ops.size (), ops.data (), out_nodes, 2 ) &&
37523756 ggml_cuda_should_use_topk_moe (softmax, logits, weights, ids) &&
3753- ggml_cuda_check_fusion_memory_ranges (cgraph, i, ops.size (), out_nodes, 2 )) {
3757+ ggml_cuda_check_fusion_memory_ranges (cgraph, i, ops.size (), out_nodes, 2 , /* is_topk_moe= */ true )) {
37543758 ggml_cuda_op_topk_moe (*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
37553759 i += ops.size () - 1 ;
37563760 continue ;
0 commit comments