Skip to content

Commit 090b137

Browse files
authored
ggml-cuda: refactor cuda graph usage (ggml-org#18637)
* ggml-cuda: refactor cuda graph usage * use is_enabled() instead of enabled
1 parent 9689295 commit 090b137

3 files changed

Lines changed: 72 additions & 96 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu {
10361036
#define USE_CUDA_GRAPH
10371037
#endif
10381038

1039-
struct ggml_graph_node_properties {
1039+
struct ggml_cuda_graph_node_properties {
10401040
void * node_address;
10411041
ggml_op node_op;
10421042
int64_t ne[GGML_MAX_DIMS];
@@ -1061,11 +1061,25 @@ struct ggml_cuda_graph {
10611061
std::vector<cudaGraphNode_t> nodes;
10621062
bool disable_due_to_gpu_arch = false;
10631063
bool disable_due_to_too_many_updates = false;
1064-
bool disable_due_to_failed_graph_capture = false;
10651064
int number_consecutive_updates = 0;
1066-
bool cuda_graphs_enabled = false;
1067-
std::vector<ggml_graph_node_properties> ggml_graph_properties;
1068-
std::vector<ggml_graph_node_properties> extraneous_srcs_properties;
1065+
std::vector<ggml_cuda_graph_node_properties> props;
1066+
1067+
void record_update(bool use_graph, bool update_required) {
1068+
if (use_graph && update_required) {
1069+
number_consecutive_updates++;
1070+
} else {
1071+
number_consecutive_updates = 0;
1072+
}
1073+
if (number_consecutive_updates >= 4) {
1074+
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
1075+
disable_due_to_too_many_updates = true;
1076+
}
1077+
}
1078+
1079+
bool is_enabled() const {
1080+
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
1081+
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
1082+
}
10691083
#endif
10701084
};
10711085

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 51 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
28532853
}
28542854

28552855
#ifdef USE_CUDA_GRAPH
2856-
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
2857-
bool use_cuda_graph) {
2856+
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
28582857

2858+
bool use_cuda_graph = true;
28592859
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
28602860

28612861
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
@@ -2915,98 +2915,97 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
29152915
return use_cuda_graph;
29162916
}
29172917

2918-
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2919-
graph_node_properties->node_address = node->data;
2920-
graph_node_properties->node_op = node->op;
2918+
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2919+
props->node_address = node->data;
2920+
props->node_op = node->op;
29212921
for (int i = 0; i < GGML_MAX_DIMS; i++) {
2922-
graph_node_properties->ne[i] = node->ne[i];
2923-
graph_node_properties->nb[i] = node->nb[i];
2922+
props->ne[i] = node->ne[i];
2923+
props->nb[i] = node->nb[i];
29242924
}
29252925
for (int i = 0; i < GGML_MAX_SRC; i++) {
2926-
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
2926+
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
29272927
}
2928-
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2928+
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
29292929
}
29302930

2931-
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2932-
if (node->data != graph_node_properties->node_address &&
2931+
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2932+
if (node->data != props->node_address &&
29332933
node->op != GGML_OP_VIEW) {
29342934
return false;
29352935
}
29362936

2937-
if (node->op != graph_node_properties->node_op) {
2937+
if (node->op != props->node_op) {
29382938
return false;
29392939
}
29402940

29412941
for (int i = 0; i < GGML_MAX_DIMS; i++) {
2942-
if (node->ne[i] != graph_node_properties->ne[i]) {
2942+
if (node->ne[i] != props->ne[i]) {
29432943
return false;
29442944
}
2945-
if (node->nb[i] != graph_node_properties->nb[i]) {
2945+
if (node->nb[i] != props->nb[i]) {
29462946
return false;
29472947
}
29482948
}
29492949

29502950
for (int i = 0; i < GGML_MAX_SRC; i++) {
29512951
if (node->src[i] &&
2952-
node->src[i]->data != graph_node_properties->src_address[i] &&
2952+
node->src[i]->data != props->src_address[i] &&
29532953
node->op != GGML_OP_VIEW
29542954
) {
29552955
return false;
29562956
}
29572957
}
29582958

29592959
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960-
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2960+
memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
29612961
return false;
29622962
}
29632963

29642964
return true;
29652965
}
29662966

2967-
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2967+
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
29682968

2969-
bool cuda_graph_update_required = false;
2969+
bool res = false;
29702970

29712971
if (cuda_ctx->cuda_graph->instance == nullptr) {
2972-
cuda_graph_update_required = true;
2972+
res = true;
29732973
}
29742974

29752975
// Check if the graph size has changed
2976-
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
2977-
cuda_graph_update_required = true;
2978-
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs);
2976+
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
2977+
res = true;
2978+
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
29792979
}
29802980

29812981
// Loop over nodes in GGML graph to determine if CUDA graph update is required
29822982
// and store properties to allow this comparison for the next token
29832983
for (int i = 0; i < cgraph->n_nodes; i++) {
2984-
bool has_matching_properties = true;
2985-
2986-
if (!cuda_graph_update_required) {
2987-
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2984+
bool props_match = true;
2985+
if (!res) {
2986+
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
29882987
}
2989-
if (!has_matching_properties) {
2990-
cuda_graph_update_required = true;
2988+
if (!props_match) {
2989+
res = true;
29912990
}
2992-
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2991+
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
29932992
}
29942993

29952994
for (int i = 0; i < cgraph->n_leafs; i++) {
2996-
bool has_matching_properties = true;
2997-
if (!cuda_graph_update_required) {
2998-
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
2995+
bool props_match= true;
2996+
if (!res) {
2997+
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
29992998
}
3000-
if (!has_matching_properties) {
3001-
cuda_graph_update_required = true;
2999+
if (!props_match) {
3000+
res = true;
30023001
}
3003-
set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]);
3002+
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
30043003
}
30053004

3006-
return cuda_graph_update_required;
3005+
return res;
30073006
}
30083007

3009-
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
3008+
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
30103009

30113010
#if CUDART_VERSION >= 12000
30123011
cudaGraphExecUpdateResultInfo result_info;
@@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
32373236
return false;
32383237
}
32393238

3240-
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
3241-
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
3239+
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3240+
bool graph_evaluated_or_captured = false;
3241+
32423242
// flag used to determine whether it is an integrated_gpu
3243-
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3243+
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
32443244

32453245
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
32463246
bool is_concurrent_event_active = false;
@@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37103710
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
37113711
}
37123712
if (cuda_graph_update_required) { // Update graph executable
3713-
update_cuda_graph_executable(cuda_ctx);
3713+
ggml_cuda_graph_update_executable(cuda_ctx);
37143714
}
37153715
// Launch graph
37163716
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37203720
}
37213721
}
37223722

3723-
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
3723+
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
37243724

37253725
#ifdef USE_CUDA_GRAPH
3726-
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
37273726

3728-
// Objects required for CUDA Graph
37293727
if (cuda_ctx->cuda_graph == nullptr) {
37303728
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
37313729
}
37323730

3733-
bool use_cuda_graph = true;
3734-
37353731
if (cuda_ctx->cuda_graph->graph == nullptr) {
37363732
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
37373733
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
3738-
#ifndef NDEBUG
37393734
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
3740-
#endif
37413735
}
37423736
}
37433737

3744-
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
3745-
// or previous graph capture failure.
3746-
// Also disable for multi-gpu for now. TO DO investigate
3747-
if (disable_cuda_graphs_due_to_env
3748-
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
3749-
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
3750-
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
3751-
use_cuda_graph = false;
3752-
}
3753-
3754-
cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph;
3738+
return cuda_ctx->cuda_graph->is_enabled();
37553739
#else
3756-
bool use_cuda_graph = false;
3740+
return false;
37573741
#endif // USE_CUDA_GRAPH
3758-
3759-
return use_cuda_graph;
37603742
}
37613743

37623744
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
@@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37673749
bool use_cuda_graph = false;
37683750
bool cuda_graph_update_required = false;
37693751

3770-
// graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
3771-
// we call it here instead.
37723752
#ifdef USE_CUDA_GRAPH
3773-
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
3774-
3775-
if (use_cuda_graph) {
3776-
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
3777-
3778-
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
3753+
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
37793754

3780-
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
3781-
if (use_cuda_graph && cuda_graph_update_required) {
3782-
cuda_ctx->cuda_graph->number_consecutive_updates++;
3783-
} else {
3784-
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
3785-
}
3755+
if (cuda_ctx->cuda_graph->is_enabled()) {
3756+
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3757+
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
37863758

3787-
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
3788-
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
3789-
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
3790-
#ifndef NDEBUG
3791-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
3792-
#endif
3793-
}
3759+
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
37943760
}
37953761
#endif // USE_CUDA_GRAPH
37963762

@@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
38043770
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
38053771
}
38063772

3807-
bool graph_evaluated_or_captured = false;
3808-
3809-
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
3773+
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
38103774

38113775
return GGML_STATUS_SUCCESS;
38123776
}
@@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
38393803
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
38403804
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
38413805

3842-
const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
3806+
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
38433807

38443808
static bool enable_graph_optimization = [] {
38453809
const char * env = getenv("GGML_CUDA_GRAPH_OPT");

ggml/src/ggml-cuda/mean.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3434
// CUDA_GRAPHS_DISABLED
3535
((ncols > 65536) &&
3636
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
37-
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
38-
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
37+
ctx.cuda_graph->is_enabled())) ||
3938
// CUDA_GRAPHS ENABLED
4039
((ncols > 32768) &&
4140
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
42-
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
43-
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
41+
ctx.cuda_graph->is_enabled()))) {
4442
#else
4543
(ncols > 65536)) {
4644
#endif // USE_CUDA_GRAPH

0 commit comments

Comments
 (0)