@@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
36963696}
36973697
36983698static bool ggml_cuda_set_cuda_graph_enabled (ggml_backend_cuda_context * cuda_ctx) {
3699+
36993700#ifdef USE_CUDA_GRAPH
37003701 static const bool disable_cuda_graphs_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
37013702
@@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
37363737static enum ggml_status ggml_backend_cuda_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
37373738 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context ;
37383739
3740+ ggml_cuda_set_device (cuda_ctx->device );
3741+
37393742 bool use_cuda_graph = false ;
37403743 bool cuda_graph_update_required = false ;
37413744
37423745 // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
37433746 // we call it here instead.
37443747#ifdef USE_CUDA_GRAPH
3745- if (!cuda_ctx->cuda_graph ) {
3746- use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
3747- } else {
3748- use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph ->cuda_graphs_enabled ;
3749- }
3748+ use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
37503749
37513750 if (use_cuda_graph) {
37523751 cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph);
@@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37623761
37633762 if (cuda_ctx->cuda_graph ->number_consecutive_updates >= 4 ) {
37643763 cuda_ctx->cuda_graph ->disable_due_to_too_many_updates = true ;
3764+ cuda_ctx->cuda_graph ->cuda_graphs_enabled = false ;
37653765#ifndef NDEBUG
37663766 GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to too many consecutive updates\n " , __func__);
37673767#endif
0 commit comments