Skip to content

Commit 441cc19

Browse files
authored
CUDA: also store node->src->data ptrs for equality check (ggml-org#21635)
* CUDA: also store node->src->data ptrs for equality check * address review comments
1 parent 9574599 commit 441cc19

2 files changed

Lines changed: 19 additions & 8 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,11 @@ struct ggml_cuda_graph {
11731173
std::vector<cudaGraphNode_t> nodes;
11741174
bool disable_due_to_gpu_arch = false;
11751175
bool warmup_complete = false;
1176-
std::vector<ggml_tensor> nodes_copy;
1176+
struct node_properties {
1177+
ggml_tensor node;
1178+
void * node_src_data_ptrs[GGML_MAX_SRC];
1179+
};
1180+
std::vector<node_properties> node_props;
11771181

11781182
bool is_enabled() const {
11791183
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
29792979
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
29802980

29812981
// Check if the graph size has changed
2982-
if ((int)graph->nodes_copy.size() != cgraph->n_nodes) {
2982+
if ((int)graph->node_props.size() != cgraph->n_nodes) {
29832983
res = true;
2984-
graph->nodes_copy.resize(cgraph->n_nodes);
2984+
graph->node_props.resize(cgraph->n_nodes);
29852985
}
29862986

29872987
for (int i = 0; i < cgraph->n_nodes; i++) {
2988-
if (!res) {
2989-
if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
2990-
res = true;
2991-
}
2988+
ggml_cuda_graph::node_properties prop = {};
2989+
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
2990+
2991+
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
2992+
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
2993+
for (int j = 0; j < GGML_MAX_SRC; ++j) {
2994+
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
2995+
}
2996+
2997+
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
2998+
res = true;
29922999
}
2993-
memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor));
3000+
graph->node_props[i] = prop;
29943001
}
29953002

29963003
return res;

0 commit comments

Comments
 (0)