Skip to content

Commit cd94d20

Browse files
am17an authored and Sascha committed
CUDA: also store node->src ne/nb for graph equality (ggml-org#21736)
1 parent a045c04 commit cd94d20

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1185,7 +1185,9 @@ struct ggml_cuda_graph {
11851185
bool warmup_complete = false;
11861186
struct node_properties {
11871187
ggml_tensor node;
1188-
void * node_src_data_ptrs[GGML_MAX_SRC];
1188+
void * node_src_data_ptrs[GGML_MAX_SRC];
1189+
int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
1190+
size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
11891191
};
11901192
std::vector<node_properties> node_props;
11911193

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3083,16 +3083,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
30833083
ggml_cuda_graph::node_properties prop = {};
30843084
memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor));
30853085

3086-
// if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see:
3087-
// https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188
30883086
for (int j = 0; j < GGML_MAX_SRC; ++j) {
3089-
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr;
3087+
if (cgraph->nodes[i]->src[j]) {
3088+
prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data;
3089+
memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j]));
3090+
memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j]));
3091+
}
30903092
}
30913093

3092-
if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3094+
if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) {
3095+
graph->node_props[i] = prop;
30933096
res = true;
30943097
}
3095-
graph->node_props[i] = prop;
30963098
}
30973099

30983100
return res;

0 commit comments

Comments
 (0)