@@ -3708,6 +3708,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37083708}
37093709
37103710static bool ggml_cuda_set_cuda_graph_enabled (ggml_backend_cuda_context * cuda_ctx) {
3711+
37113712#ifdef USE_CUDA_GRAPH
37123713 static const bool disable_cuda_graphs_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
37133714
@@ -3748,17 +3749,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
37483749static enum ggml_status ggml_backend_cuda_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
37493750 ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context ;
37503751
3752+ ggml_cuda_set_device (cuda_ctx->device );
3753+
37513754 bool use_cuda_graph = false ;
37523755 bool cuda_graph_update_required = false ;
37533756
37543757 // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
37553758 // we call it here instead.
37563759#ifdef USE_CUDA_GRAPH
3757- if (!cuda_ctx->cuda_graph ) {
3758- use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
3759- } else {
3760- use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph ->cuda_graphs_enabled ;
3761- }
3760+ use_cuda_graph = ggml_cuda_set_cuda_graph_enabled (cuda_ctx);
37623761
37633762 if (use_cuda_graph) {
37643763 cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph);
@@ -3774,6 +3773,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37743773
37753774 if (cuda_ctx->cuda_graph ->number_consecutive_updates >= 4 ) {
37763775 cuda_ctx->cuda_graph ->disable_due_to_too_many_updates = true ;
3776+ cuda_ctx->cuda_graph ->cuda_graphs_enabled = false ;
37773777#ifndef NDEBUG
37783778 GGML_LOG_DEBUG (" %s: disabling CUDA graphs due to too many consecutive updates\n " , __func__);
37793779#endif
0 commit comments