Skip to content

Commit a29e4c0

Browse files
authored
CUDA: also store node->src ne/nb for graph equality (ggml-org#21736)
1 parent b136b62 commit a29e4c0

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1185,7 +1185,9 @@ struct ggml_cuda_graph {
11851185
bool warmup_complete = false;
11861186
struct node_properties {
11871187
ggml_tensor node;
1188-
void * node_src_data_ptrs[GGML_MAX_SRC];
1188+
void * node_src_data_ptrs[GGML_MAX_SRC];
1189+
int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
1190+
size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
11891191
};
11901192
std::vector<node_properties> node_props;
11911193

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3070,16 +3070,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
30703070
ggml_cuda_graph::node_properties prop = {};
30713071
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
30723072

3073-
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
3074-
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
30753073
for (int j = 0; j < GGML_MAX_SRC; ++j) {
3076-
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
3074+
if (cgraph->nodes[i]->src[j]) {
3075+
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
3076+
memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
3077+
memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
3078+
}
30773079
}
30783080

3079-
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3081+
if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3082+
graph->node_props[i] = prop;
30803083
res = true;
30813084
}
3082-
graph->node_props[i] = prop;
30833085
}
30843086

30853087
return res;

0 commit comments

Comments
 (0)