Skip to content

Commit ad6c53a

Browse files
committed
Merge commit '908a9e5a1eaaff345f05087beafdf43d31e3f00a' into concedo
2 parents acfc1e5 + 908a9e5 commit ad6c53a

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1065,7 +1065,6 @@ struct ggml_cuda_graph {
10651065
cudaGraphExec_t instance = nullptr;
10661066
size_t num_nodes = 0;
10671067
std::vector<cudaGraphNode_t> nodes;
1068-
std::vector<cudaKernelNodeParams> params;
10691068
bool disable_due_to_gpu_arch = false;
10701069
bool disable_due_to_too_many_updates = false;
10711070
bool disable_due_to_failed_graph_capture = false;

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3708,6 +3708,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
37083708
}
37093709

37103710
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
3711+
37113712
#ifdef USE_CUDA_GRAPH
37123713
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
37133714

@@ -3748,17 +3749,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
37483749
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
37493750
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
37503751

3752+
ggml_cuda_set_device(cuda_ctx->device);
3753+
37513754
bool use_cuda_graph = false;
37523755
bool cuda_graph_update_required = false;
37533756

37543757
// graph_optimize calls set_cuda_graph_enabled; in case it was not called (i.e. graph_compute is called directly),
37553758
// we call it here instead.
37563759
#ifdef USE_CUDA_GRAPH
3757-
if (!cuda_ctx->cuda_graph) {
3758-
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
3759-
} else {
3760-
use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
3761-
}
3760+
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
37623761

37633762
if (use_cuda_graph) {
37643763
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
@@ -3774,6 +3773,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37743773

37753774
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
37763775
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
3776+
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
37773777
#ifndef NDEBUG
37783778
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
37793779
#endif

0 commit comments

Comments
 (0)