@@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
28532853}
28542854
28552855#ifdef USE_CUDA_GRAPH
2856- static bool check_node_graph_compatibility (ggml_cgraph * cgraph,
2857- bool use_cuda_graph) {
2856+ static bool ggml_cuda_graph_check_compability (ggml_cgraph * cgraph) {
28582857
2858+ bool use_cuda_graph = true ;
28592859 // Loop over nodes in GGML graph to obtain info needed for CUDA graph
28602860
28612861 const std::string gemma3n_per_layer_proj_src0_name = " inp_per_layer_selected" ;
@@ -2915,98 +2915,97 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
29152915 return use_cuda_graph;
29162916}
29172917
2918- static void set_ggml_graph_node_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties ) {
2919- graph_node_properties ->node_address = node->data ;
2920- graph_node_properties ->node_op = node->op ;
2918+ static void ggml_cuda_graph_node_set_properties (ggml_cuda_graph_node_properties * props, ggml_tensor * node ) {
2919+ props ->node_address = node->data ;
2920+ props ->node_op = node->op ;
29212921 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
2922- graph_node_properties ->ne [i] = node->ne [i];
2923- graph_node_properties ->nb [i] = node->nb [i];
2922+ props ->ne [i] = node->ne [i];
2923+ props ->nb [i] = node->nb [i];
29242924 }
29252925 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2926- graph_node_properties ->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
2926+ props ->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
29272927 }
2928- memcpy (graph_node_properties ->op_params , node->op_params , GGML_MAX_OP_PARAMS);
2928+ memcpy (props ->op_params , node->op_params , GGML_MAX_OP_PARAMS);
29292929}
29302930
2931- static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties ) {
2932- if (node->data != graph_node_properties ->node_address &&
2931+ static bool ggml_cuda_graph_node_properties_match (ggml_tensor * node, ggml_cuda_graph_node_properties * props ) {
2932+ if (node->data != props ->node_address &&
29332933 node->op != GGML_OP_VIEW) {
29342934 return false ;
29352935 }
29362936
2937- if (node->op != graph_node_properties ->node_op ) {
2937+ if (node->op != props ->node_op ) {
29382938 return false ;
29392939 }
29402940
29412941 for (int i = 0 ; i < GGML_MAX_DIMS; i++) {
2942- if (node->ne [i] != graph_node_properties ->ne [i]) {
2942+ if (node->ne [i] != props ->ne [i]) {
29432943 return false ;
29442944 }
2945- if (node->nb [i] != graph_node_properties ->nb [i]) {
2945+ if (node->nb [i] != props ->nb [i]) {
29462946 return false ;
29472947 }
29482948 }
29492949
29502950 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
29512951 if (node->src [i] &&
2952- node->src [i]->data != graph_node_properties ->src_address [i] &&
2952+ node->src [i]->data != props ->src_address [i] &&
29532953 node->op != GGML_OP_VIEW
29542954 ) {
29552955 return false ;
29562956 }
29572957 }
29582958
29592959 if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
2960- memcmp (graph_node_properties ->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2960+ memcmp (props ->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
29612961 return false ;
29622962 }
29632963
29642964 return true ;
29652965}
29662966
2967- static bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2967+ static bool ggml_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
29682968
2969- bool cuda_graph_update_required = false ;
2969+ bool res = false ;
29702970
29712971 if (cuda_ctx->cuda_graph ->instance == nullptr ) {
2972- cuda_graph_update_required = true ;
2972+ res = true ;
29732973 }
29742974
29752975 // Check if the graph size has changed
2976- if (cuda_ctx->cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes + cgraph->n_leafs ) {
2977- cuda_graph_update_required = true ;
2978- cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes + cgraph->n_leafs );
2976+ if (cuda_ctx->cuda_graph ->props .size () != (size_t )cgraph->n_nodes + cgraph->n_leafs ) {
2977+ res = true ;
2978+ cuda_ctx->cuda_graph ->props .resize (cgraph->n_nodes + cgraph->n_leafs );
29792979 }
29802980
29812981 // Loop over nodes in GGML graph to determine if CUDA graph update is required
29822982 // and store properties to allow this comparison for the next token
29832983 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2984- bool has_matching_properties = true ;
2985-
2986- if (!cuda_graph_update_required) {
2987- has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2984+ bool props_match = true ;
2985+ if (!res) {
2986+ props_match = ggml_cuda_graph_node_properties_match (cgraph->nodes [i], &cuda_ctx->cuda_graph ->props [i]);
29882987 }
2989- if (!has_matching_properties ) {
2990- cuda_graph_update_required = true ;
2988+ if (!props_match ) {
2989+ res = true ;
29912990 }
2992- set_ggml_graph_node_properties (cgraph-> nodes [i], &cuda_ctx-> cuda_graph -> ggml_graph_properties [i]);
2991+ ggml_cuda_graph_node_set_properties (&cuda_ctx-> cuda_graph -> props [i], cgraph-> nodes [i]);
29932992 }
29942993
29952994 for (int i = 0 ; i < cgraph->n_leafs ; i++) {
2996- bool has_matching_properties = true ;
2997- if (!cuda_graph_update_required ) {
2998- has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->leafs [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [cgraph->n_nodes + i]);
2995+ bool props_match = true ;
2996+ if (!res ) {
2997+ props_match = ggml_cuda_graph_node_properties_match (cgraph->leafs [i], &cuda_ctx->cuda_graph ->props [cgraph->n_nodes + i]);
29992998 }
3000- if (!has_matching_properties ) {
3001- cuda_graph_update_required = true ;
2999+ if (!props_match ) {
3000+ res = true ;
30023001 }
3003- set_ggml_graph_node_properties (cgraph-> leafs [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [cgraph->n_nodes + i]);
3002+ ggml_cuda_graph_node_set_properties ( &cuda_ctx->cuda_graph ->props [cgraph->n_nodes + i], cgraph-> leafs [ i]);
30043003 }
30053004
3006- return cuda_graph_update_required ;
3005+ return res ;
30073006}
30083007
3009- static void update_cuda_graph_executable (ggml_backend_cuda_context * cuda_ctx) {
3008+ static void ggml_cuda_graph_update_executable (ggml_backend_cuda_context * cuda_ctx) {
30103009
30113010#if CUDART_VERSION >= 12000
30123011 cudaGraphExecUpdateResultInfo result_info;
@@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
32373236 return false ;
32383237}
32393238
3240- static void evaluate_and_capture_cuda_graph (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
3241- bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
3239+ static void ggml_cuda_graph_evaluate_and_capture (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
3240+ bool graph_evaluated_or_captured = false ;
3241+
32423242 // flag used to determine whether it is an integrated_gpu
3243- const bool integrated = ggml_cuda_info ().devices [cuda_ctx->device ].integrated ;
3243+ const bool integrated = ggml_cuda_info ().devices [cuda_ctx->device ].integrated ;
32443244
32453245 ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context ();
32463246 bool is_concurrent_event_active = false ;
@@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37103710 CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
37113711 }
37123712 if (cuda_graph_update_required) { // Update graph executable
3713- update_cuda_graph_executable (cuda_ctx);
3713+ ggml_cuda_graph_update_executable (cuda_ctx);
37143714 }
37153715 // Launch graph
37163716 CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
@@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37203720 }
37213721}
37223722
3723- static bool ggml_cuda_set_cuda_graph_enabled (ggml_backend_cuda_context * cuda_ctx) {
3723+ static bool ggml_cuda_graph_set_enabled (ggml_backend_cuda_context * cuda_ctx) {
37243724
37253725#ifdef USE_CUDA_GRAPH
3726- static const bool disable_cuda_graphs_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
37273726
3728- // Objects required for CUDA Graph
37293727 if (cuda_ctx->cuda_graph == nullptr ) {
37303728 cuda_ctx->cuda_graph .reset (new ggml_cuda_graph ());
37313729 }
37323730
3733- bool use_cuda_graph = true ;
3734-
37353731 if (cuda_ctx->cuda_graph ->graph == nullptr ) {
37363732 if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < GGML_CUDA_CC_AMPERE) {
37373733 cuda_ctx->cuda_graph ->disable_due_to_gpu_arch = true ;
3738- #ifndef NDEBUG
37393734 GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to GPU architecture\n " , __func__);
3740- #endif
37413735 }
37423736 }
37433737
3744- // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
3745- // or previous graph capture failure.
3746- // Also disable for multi-gpu for now. TO DO investigate
3747- if (disable_cuda_graphs_due_to_env
3748- || cuda_ctx->cuda_graph ->disable_due_to_gpu_arch
3749- || cuda_ctx->cuda_graph ->disable_due_to_too_many_updates
3750- || cuda_ctx->cuda_graph ->disable_due_to_failed_graph_capture ) {
3751- use_cuda_graph = false ;
3752- }
3753-
3754- cuda_ctx->cuda_graph ->cuda_graphs_enabled = use_cuda_graph;
3738+ return cuda_ctx->cuda_graph ->is_enabled ();
37553739#else
3756- bool use_cuda_graph = false ;
3740+ return false ;
37573741#endif // USE_CUDA_GRAPH
3758-
3759- return use_cuda_graph;
37603742}
37613743
37623744static enum ggml_status ggml_backend_cuda_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
@@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37673749 bool use_cuda_graph = false ;
37683750 bool cuda_graph_update_required = false ;
37693751
3770- // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
3771- // we call it here instead.
37723752#ifdef USE_CUDA_GRAPH
3773- use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
3774-
3775- if (use_cuda_graph) {
3776- cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph);
3777-
3778- use_cuda_graph = check_node_graph_compatibility (cgraph, use_cuda_graph);
3753+ use_cuda_graph = ggml_cuda_graph_set_enabled (cuda_ctx);
37793754
3780- // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
3781- if (use_cuda_graph && cuda_graph_update_required) {
3782- cuda_ctx->cuda_graph ->number_consecutive_updates ++;
3783- } else {
3784- cuda_ctx->cuda_graph ->number_consecutive_updates = 0 ;
3785- }
3755+ if (cuda_ctx->cuda_graph ->is_enabled ()) {
3756+ cuda_graph_update_required = ggml_cuda_graph_update_required (cuda_ctx, cgraph);
3757+ use_cuda_graph = ggml_cuda_graph_check_compability (cgraph);
37863758
3787- if (cuda_ctx->cuda_graph ->number_consecutive_updates >= 4 ) {
3788- cuda_ctx->cuda_graph ->disable_due_to_too_many_updates = true ;
3789- cuda_ctx->cuda_graph ->cuda_graphs_enabled = false ;
3790- #ifndef NDEBUG
3791- GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to too many consecutive updates\n " , __func__);
3792- #endif
3793- }
3759+ cuda_ctx->cuda_graph ->record_update (use_cuda_graph, cuda_graph_update_required);
37943760 }
37953761#endif // USE_CUDA_GRAPH
37963762
@@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
38043770 CUDA_CHECK (cudaStreamBeginCapture (cuda_ctx->stream (), cudaStreamCaptureModeRelaxed));
38053771 }
38063772
3807- bool graph_evaluated_or_captured = false ;
3808-
3809- evaluate_and_capture_cuda_graph (cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
3773+ ggml_cuda_graph_evaluate_and_capture (cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
38103774
38113775 return GGML_STATUS_SUCCESS;
38123776}
@@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
38393803static void ggml_backend_cuda_graph_optimize (ggml_backend_t backend, ggml_cgraph * cgraph) {
38403804 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context ;
38413805
3842- const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
3806+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled (cuda_ctx);
38433807
38443808 static bool enable_graph_optimization = [] {
38453809 const char * env = getenv (" GGML_CUDA_GRAPH_OPT" );
0 commit comments