Skip to content

Commit de1aa6f

Browse files
authored
CUDA: check for buffer overlap before fusing (#21566)
* CUDA: check for buffer overlap before fusing
* use ggml_cuda_check_fusion_memory_ranges
1 parent 69c28f1 commit de1aa6f

File tree

1 file changed

+71
-67
lines changed

1 file changed

+71
-67
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 71 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3308,6 +3308,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod
33083308
return true;
33093309
}
33103310

3311+
// Checks whether it is safe to fuse the nodes [node_idx, node_idx + node_count) given the
// output tensors the fused kernel writes (out_nodes/out_count, indices into cgraph->nodes).
// Returns true when no output tensor's allocated buffer range overlaps the buffer of any
// source tensor the fused nodes still read (i.e. fusing cannot overwrite live inputs);
// returns false when such an overlap exists. Overlap with an intermediate tensor that an
// earlier node of the fusion produces is harmless, because the fused kernel elides it.
//
// is_topk_moe: the topk-moe kernels read each row entirely before writing, so a
//              single-row graph is always safe regardless of overlap.
static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph,
                                                 const int           node_idx,
                                                 const int           node_count,
                                                 const int *         out_nodes,
                                                 const int           out_count,
                                                 const bool          is_topk_moe = false) {
    // two tensors overlap if their allocated [start, end) byte ranges intersect
    auto nodes_overlap = [](const ggml_tensor * a, const ggml_tensor * b) {
        const int64_t a_start = (int64_t) a->data;
        const int64_t a_end   = a_start + (int64_t) ggml_backend_buft_get_alloc_size(a->buffer->buft, a);

        const int64_t b_start = (int64_t) b->data;
        const int64_t b_end   = b_start + (int64_t) ggml_backend_buft_get_alloc_size(b->buffer->buft, b);

        return (b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end);
    };

    // exception for topk-moe, as each row is read entirely before writing
    if (is_topk_moe && ggml_nrows(cgraph->nodes[node_idx]) == 1) {
        return true;
    }

    for (int i = 0; i < out_count; ++i) {
        const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];

        for (int j = node_idx; j < node_idx + node_count; ++j) {
            // Loop over all srcs of all nodes in the fusion. If a src overlaps the
            // destination and the src is not an intermediate node that's being
            // elided, then disable fusion.
            for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
                const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];

                if (!src || src->op == GGML_OP_NONE) {
                    continue;
                }

                if (!nodes_overlap(dst, src)) {
                    continue;
                }

                // an earlier node of the fusion produced this src: it is elided by
                // the fused kernel, so the overlap is harmless
                bool is_elided = false;
                for (int k = node_idx; k < j; ++k) {
                    if (cgraph->nodes[k] == src) {
                        is_elided = true;
                        break;
                    }
                }

                if (!is_elided) {
                    // overlap with a live input -> fusing would overwrite data
                    // that is still being read
                    return false;
                }
            }
        }
    }

    return true;
}
3374+
3375+
33113376
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
33123377
int node_idx,
33133378
std::initializer_list<enum ggml_op> ops,
@@ -3337,7 +3402,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
33373402
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
33383403

33393404
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
3340-
return true;
3405+
int out_nodes[] = { node_idx + 4 };
3406+
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
33413407
}
33423408
}
33433409

@@ -3348,7 +3414,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
33483414
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
33493415

33503416
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
3351-
return true;
3417+
int out_nodes[] = { node_idx + 2 };
3418+
return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1);
33523419
}
33533420
}
33543421

@@ -3474,69 +3541,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
34743541
return false;
34753542
}
34763543

3477-
// Decides whether fusing nodes [node_idx, node_idx + node_count) is memory-safe: returns
// false when one of the fused kernel's output tensors (out_nodes/out_count, indices into
// cgraph->nodes) shares buffer bytes with a source tensor the fusion still reads, and
// true otherwise. Sources produced by an earlier node of the same fusion are elided by
// the fused kernel and therefore ignored.
static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
                                                 int node_idx,
                                                 int node_count,
                                                 int * out_nodes,
                                                 int out_count) {
    // true iff the allocated [begin, end) byte ranges of the two tensors intersect
    auto ranges_intersect = [](const ggml_tensor * a, const ggml_tensor * b) -> bool {
        const int64_t a_begin = (int64_t) a->data;
        const int64_t b_begin = (int64_t) b->data;
        const int64_t a_close = a_begin + ggml_nbytes(a);
        const int64_t b_close = b_begin + ggml_nbytes(b);

        const bool a_starts_inside_b = b_begin <= a_begin && a_begin < b_close;
        const bool b_starts_inside_a = a_begin <= b_begin && b_begin < a_close;

        return a_starts_inside_b || b_starts_inside_a;
    };

    // for nrows=1, all fusion operations correctly read the src before writing dst or
    // do it elementwise, so we should be ok
    if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
        return true;
    }

    for (int out = 0; out < out_count; ++out) {
        const ggml_tensor * written = cgraph->nodes[out_nodes[out]];

        for (int node = node_idx; node < node_idx + node_count; ++node) {
            // Scan every src of every fused node; an overlap with the written tensor is
            // only acceptable when the src is an intermediate produced inside the fusion.
            for (int s = 0; s < GGML_MAX_SRC; ++s) {
                const ggml_tensor * read = cgraph->nodes[node]->src[s];

                if (read == nullptr || read->op == GGML_OP_NONE) {
                    continue;
                }

                if (!ranges_intersect(written, read)) {
                    continue;
                }

                // was this src produced by an earlier node of the fusion (=> elided)?
                bool produced_in_fusion = false;
                for (int prev = node_idx; prev < node; ++prev) {
                    if (cgraph->nodes[prev] == read) {
                        produced_in_fusion = true;
                        break;
                    }
                }

                if (!produced_in_fusion) {
                    return false;
                }
            }
        }
    }

    return true;
}
3539-
35403544
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
35413545
bool graph_evaluated_or_captured = false;
35423546

@@ -3734,7 +3738,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37343738

37353739
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
37363740
ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
3737-
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3741+
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
37383742
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
37393743
i += ops.size() - 1;
37403744
continue;
@@ -3750,7 +3754,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
37503754
int out_nodes[2] = { i + 1, i + 5 };
37513755
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
37523756
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
3753-
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
3757+
ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) {
37543758
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
37553759
i += ops.size() - 1;
37563760
continue;

0 commit comments

Comments
 (0)