8282#include < cstdlib>
8383#include < string>
8484#include < vector>
85- #include < unordered_set>
8685
8786static_assert (sizeof (half) == sizeof (ggml_fp16_t ), " wrong fp16 size" );
8887
@@ -3041,48 +3040,12 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
30413040 return cgraph->nodes [0 ];
30423041}
30433042
3044- // compute a FNV-1a over all nodes and srcs which should change when a cuda graph cannot be reused
3045- static uint64_t ggml_cuda_graph_hash (ggml_cgraph * cgraph) {
3046- uint64_t h = 0xcbf29ce484222325ULL ;
3047- constexpr uint64_t prime = 0x100000001b3ULL ;
3048-
3049- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3050- const ggml_tensor * node = cgraph->nodes [i];
3051-
3052- h ^= (uintptr_t )node->data ;
3053- h *= prime;
3054-
3055- for (int s = 0 ; s < GGML_MAX_SRC; s++) {
3056- if (node->src [s]) {
3057- h ^= (uintptr_t )node->src [s]->data ;
3058- h *= prime;
3059- }
3060- }
3061-
3062- // Hash first 16 bytes of op_params
3063- const uint64_t * params = (const uint64_t *)node->op_params ;
3064- h ^= params[0 ];
3065- h *= prime;
3066- h ^= params[1 ];
3067- h *= prime;
3068- }
3069-
3070- return h;
3071- }
3072-
30733043static bool ggml_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
30743044 bool res = false ;
30753045
30763046 const void * graph_key = ggml_cuda_graph_get_key (cgraph);
30773047 ggml_cuda_graph * graph = cuda_ctx->cuda_graph (graph_key);
30783048
3079- if (graph->props_stable >= 2 && graph->props .size () == (size_t )cgraph->n_nodes ) {
3080- if (ggml_cuda_graph_hash (cgraph) == graph->last_props_hash ) {
3081- return false ;
3082- }
3083- graph->props_stable = 0 ;
3084- }
3085-
30863049 // Check if the graph size has changed
30873050 if (graph->props .size () != (size_t )cgraph->n_nodes ) {
30883051 res = true ;
@@ -3091,12 +3054,16 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
30913054
30923055 // Loop over nodes in GGML graph to determine if CUDA graph update is required
30933056 // and store properties to allow this comparison for the next token
3094- std::unordered_set<ggml_tensor *> seen_node;
3095- std::vector<ggml_tensor *> srcs_extra;
3057+
3058+ const int32_t flag_seen = GGML_TENSOR_FLAG_UNUSED;
3059+
30963060 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3097- bool props_match = true ;
3061+ cgraph->nodes [i]->flags |= flag_seen;
3062+ }
30983063
3099- seen_node.insert (cgraph->nodes [i]);
3064+ size_t extra_idx = 0 ;
3065+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3066+ bool props_match = true ;
31003067
31013068 if (!res) {
31023069 props_match = ggml_cuda_graph_node_properties_match (cgraph->nodes [i], &graph->props [i]);
@@ -3108,35 +3075,30 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
31083075
31093076 for (int src_idx = 0 ; src_idx < GGML_MAX_SRC; ++src_idx) {
31103077 ggml_tensor * src = cgraph->nodes [i]->src [src_idx];
3111- if (src && seen_node.find (src) == seen_node.end ()) {
3112- srcs_extra.push_back (src);
3078+ if (src && !(src->flags & flag_seen)) {
3079+ if (extra_idx >= graph->extra .size ()) {
3080+ graph->extra .push_back ({});
3081+ res = true ;
3082+ }
3083+
3084+ if (!res) {
3085+ if (!ggml_cuda_graph_node_properties_match (src, &graph->extra [extra_idx])) {
3086+ res = true ;
3087+ }
3088+ }
3089+ ggml_cuda_graph_node_set_properties (&graph->extra [extra_idx], src);
3090+ extra_idx++;
31133091 }
31143092 }
31153093 }
31163094
3117- if (graph->extra .size () != ( size_t ) srcs_extra. size () ) {
3095+ if (graph->extra .size () != extra_idx ) {
31183096 res = true ;
3119- graph->extra .resize (srcs_extra.size ());
3120- }
3121-
3122- for (size_t i = 0 ; i < srcs_extra.size (); ++i) {
3123- bool props_match = true ;
3124-
3125- if (!res) {
3126- props_match = ggml_cuda_graph_node_properties_match (srcs_extra[i], &graph->extra [i]);
3127- }
3128-
3129- if (!props_match) {
3130- res = true ;
3131- }
3132- ggml_cuda_graph_node_set_properties (&graph->extra [i], srcs_extra[i]);
3097+ graph->extra .resize (extra_idx);
31333098 }
31343099
3135- if (!res) {
3136- graph->props_stable ++;
3137- graph->last_props_hash = ggml_cuda_graph_hash (cgraph);
3138- } else {
3139- graph->props_stable = 0 ;
3100+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
3101+ cgraph->nodes [i]->flags &= ~flag_seen;
31403102 }
31413103
31423104 return res;
0 commit comments