Skip to content

Commit 908a9e5

Browse files
authored
CUDA: disable cuda graph when using n-cpu-moe (ggml-org#18593)
* CUDA: disable cuda graph when using n-cpu-moe * call ggml_cuda_set_device
1 parent 5126c41 commit 908a9e5

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
36963696
}
36973697

36983698
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
3699+
36993700
#ifdef USE_CUDA_GRAPH
37003701
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
37013702

@@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
37363737
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
37373738
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
37383739

3740+
ggml_cuda_set_device(cuda_ctx->device);
3741+
37393742
bool use_cuda_graph = false;
37403743
bool cuda_graph_update_required = false;
37413744

37423745
// graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
37433746
// we call it here instead.
37443747
#ifdef USE_CUDA_GRAPH
3745-
if (!cuda_ctx->cuda_graph) {
3746-
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
3747-
} else {
3748-
use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
3749-
}
3748+
use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
37503749

37513750
if (use_cuda_graph) {
37523751
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
@@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
37623761

37633762
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
37643763
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
3764+
cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
37653765
#ifndef NDEBUG
37663766
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
37673767
#endif

0 commit comments

Comments
 (0)