3737#pragma GCC diagnostic push
3838#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
3939
40- enum ggml_status ov_graph_compute (ggml_cgraph * cgraph) {
40+ enum ggml_status ov_graph_compute (ggml_cgraph * cgraph, ggml_backend_t backend) {
41+ ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context ;
4142 try {
4243 if (getenv (" GGML_OPENVINO_DUMP_CGRAPH" )) {
4344 std::string filename = " cgraph_ov.txt" ;
4445 GgmlOvDecoder::dump_cgraph (cgraph, filename);
4546 }
4647
47- // Use device from singleton (initialized during backend init)
48- const auto & device = ggml_openvino_get_device_name ();
4948 const auto is_static = ggml_openvino_is_npu ();
50- bool stateful = false ;
49+
50+ if (ctx->ov_runtime_context == nullptr ) {
51+ ctx->ov_runtime_context = std::make_shared<ov_runtime_context>();
52+ }
53+ std::shared_ptr<ov_runtime_context> r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->ov_runtime_context );
54+ r_ctx->device = ggml_openvino_get_device_name ();
55+ r_ctx->stateful = false ;
5156 if (getenv (" GGML_OPENVINO_STATEFUL_EXECUTION" ) && !is_static) {
52- stateful = true ;
57+ r_ctx-> stateful = true ;
5358 }
5459
55- return is_static ? ov_graph_compute_static (cgraph) : ov_graph_compute_dynamic (cgraph, device, stateful );
60+ return is_static ? ov_graph_compute_static (cgraph, r_ctx ) : ov_graph_compute_dynamic (cgraph, r_ctx );
5661 } catch (const ov::Exception & e) {
5762 GGML_LOG_ERROR (" GGML OpenVINO backend ov::Exception: %s\n " , e.what ());
5863 return GGML_STATUS_FAILED;
@@ -65,24 +70,19 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) {
6570 }
6671}
6772
68- enum ggml_status ov_graph_compute_dynamic (ggml_cgraph * cgraph, const std::string & device, bool stateful ) {
73+ enum ggml_status ov_graph_compute_dynamic (ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx ) {
6974 auto & core = ov_singleton_core ();
7075 const auto & config = ggml_openvino_get_compile_config ();
76+ auto device = r_ctx->device ;
77+ bool stateful = r_ctx->stateful ;
7178 static auto is_static = false ;
72- static size_t stateful_kv_size = 0 ;
7379
7480 if (is_naive (cgraph)) {
7581 return naive_compute (cgraph, core, device, config);
7682 }
7783
7884 auto start_time = ggml_time_us ();
7985
80- static std::mutex cache_mutex;
81- static std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
82- static std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
83- static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
84- static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;
85-
8686 std::shared_ptr<GgmlOvDecoder> ggml_decoder;
8787 std::shared_ptr<ov::InferRequest> infer_request;
8888 ModelParams m_params;
@@ -98,11 +98,11 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
9898 int64_t infer_end_time;
9999
100100 {
101- std::lock_guard<std::mutex> lock (cache_mutex);
101+ std::lock_guard<std::mutex> lock (r_ctx-> cache_mutex );
102102
103- auto it = decoder_cache.find (key);
103+ auto it = r_ctx-> decoder_cache .find (key);
104104
105- cache_hit = it != decoder_cache.end ();
105+ cache_hit = it != r_ctx-> decoder_cache .end ();
106106 ModelParams old_m_params;
107107 if (cache_hit) {
108108 ggml_decoder = it->second ;
@@ -118,17 +118,17 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
118118 ggml_decoder->update_io (cgraph);
119119 }
120120 ggml_decoder->add_extra_inputs ();
121- infer_request = infer_request_cache.at (key);
121+ infer_request = r_ctx-> infer_request_cache .at (key);
122122
123123 if (stateful) {
124124 const auto * inp_pos = get_inp_pos_tensor (cgraph);
125125 int32_t * pos_data = (int32_t *) inp_pos->data ;
126126 auto pos_shape = ggml_decoder->get_shape (inp_pos);
127127 if (pos_data[0 ] == 0 ) {
128128 infer_request->reset_state ();
129- stateful_kv_size = pos_shape[3 ];
130- } else if (stateful_kv_size == static_cast <size_t >(pos_data[0 ])) {
131- stateful_kv_size += pos_shape[3 ];
129+ r_ctx-> stateful_kv_size = pos_shape[3 ];
130+ } else if (r_ctx-> stateful_kv_size == static_cast <size_t >(pos_data[0 ])) {
131+ r_ctx-> stateful_kv_size += pos_shape[3 ];
132132 } else {
133133 auto states = infer_request->query_state ();
134134 for (auto state : states) {
@@ -138,15 +138,15 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
138138 ov::Tensor new_state_tensor (state_tensor, begin, end);
139139 state.set_state (new_state_tensor);
140140 }
141- stateful_kv_size = pos_data[0 ] + 1 ;
141+ r_ctx-> stateful_kv_size = pos_data[0 ] + 1 ;
142142 }
143143 }
144144
145145 decoder_end_time = ggml_time_us ();
146146 conversion_end_time = decoder_end_time;
147147 compile_end_time = decoder_end_time;
148148 } else {
149- infer_request_cache.erase (key);
149+ r_ctx-> infer_request_cache .erase (key);
150150
151151 std::shared_ptr<ov::Model> model;
152152 auto model_weights = GgmlOvDecoder::create_weight_nodes (cgraph);
@@ -176,8 +176,8 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
176176 }
177177 compile_end_time = ggml_time_us ();
178178 infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request ());
179- infer_request_cache[key] = infer_request;
180- decoder_cache[key] = ggml_decoder;
179+ r_ctx-> infer_request_cache [key] = infer_request;
180+ r_ctx-> decoder_cache [key] = ggml_decoder;
181181
182182 std::vector<std::string> ov_input_names;
183183 std::vector<std::string> ov_output_names;
@@ -187,12 +187,16 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
187187 for (const auto & ov_output : model->get_results ()) {
188188 ov_output_names.push_back (ov_output->get_friendly_name ());
189189 }
190- ov_input_names_cache[key] = std::move (ov_input_names);
191- ov_output_names_cache[key] = std::move (ov_output_names);
190+ r_ctx->ov_input_names_cache [key] = std::move (ov_input_names);
191+ r_ctx->ov_output_names_cache [key] = std::move (ov_output_names);
192+
193+ if (stateful) {
194+ r_ctx->stateful_kv_size = 0 ;
195+ }
192196 }
193197
194- auto ov_input_names = ov_input_names_cache[key];
195- auto ov_output_names = ov_output_names_cache[key];
198+ auto ov_input_names = r_ctx-> ov_input_names_cache [key];
199+ auto ov_output_names = r_ctx-> ov_output_names_cache [key];
196200
197201 for (size_t i = 0 ; i < ov_input_names.size (); i++) {
198202 auto param_name = ov_input_names[i];
@@ -233,7 +237,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
233237 return GGML_STATUS_SUCCESS;
234238}
235239
236- enum ggml_status ov_graph_compute_static (ggml_cgraph * cgraph) {
240+ enum ggml_status ov_graph_compute_static (ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx ) {
237241 auto & core = ov_singleton_core ();
238242
239243 auto get_prefill_chunk_size = [] {
@@ -256,13 +260,6 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
256260
257261 auto start_time = ggml_time_us ();
258262
259- static std::mutex cache_mutex;
260- static std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
261- static std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
262- static std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
263- static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
264- static std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;
265-
266263 std::shared_ptr<GgmlOvDecoder> ggml_decoder;
267264 std::shared_ptr<ov::InferRequest> infer_request;
268265 ModelParams m_params;
@@ -280,11 +277,11 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
280277 int64_t infer_end_time;
281278
282279 {
283- std::lock_guard<std::mutex> lock (cache_mutex);
280+ std::lock_guard<std::mutex> lock (r_ctx-> cache_mutex );
284281
285- auto it = decoder_cache.find (key);
282+ auto it = r_ctx-> decoder_cache .find (key);
286283
287- cache_hit = it != decoder_cache.end ();
284+ cache_hit = it != r_ctx-> decoder_cache .end ();
288285 ModelParams old_m_params;
289286 if (cache_hit) {
290287 ggml_decoder = it->second ;
@@ -301,14 +298,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
301298 ggml_decoder->update_io (cgraph);
302299 }
303300 ggml_decoder->add_extra_inputs ();
304- infer_request = is_prefill ? infer_request_cache_prefill.at (key) : infer_request_cache.at (key);
301+ infer_request = is_prefill ? r_ctx-> infer_request_cache_prefill .at (key) : r_ctx-> infer_request_cache .at (key);
305302
306303 decoder_end_time = ggml_time_us ();
307304 conversion_end_time = decoder_end_time;
308305 compile_end_time = decoder_end_time;
309306 } else {
310- infer_request_cache.erase (key);
311- infer_request_cache_prefill.erase (key);
307+ r_ctx-> infer_request_cache .erase (key);
308+ r_ctx-> infer_request_cache_prefill .erase (key);
312309
313310 std::shared_ptr<ov::Model> model;
314311 auto model_weights = GgmlOvDecoder::create_weight_nodes (cgraph);
@@ -348,15 +345,15 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
348345 compiled_model_decode = core.compile_model (model_decode, device, config);
349346 }
350347
351- infer_request_cache_prefill[key] =
348+ r_ctx-> infer_request_cache_prefill [key] =
352349 std::make_shared<ov::InferRequest>(compiled_model_prefill.create_infer_request ());
353- infer_request_cache[key] = std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request ());
350+ r_ctx-> infer_request_cache [key] = std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request ());
354351 compile_end_time = ggml_time_us ();
355352
356353 model = is_prefill ? model_prefill : model_decode;
357354 ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
358- infer_request = is_prefill ? infer_request_cache_prefill[key] : infer_request_cache[key];
359- decoder_cache[key] = ggml_decoder;
355+ infer_request = is_prefill ? r_ctx-> infer_request_cache_prefill [key] : r_ctx-> infer_request_cache [key];
356+ r_ctx-> decoder_cache [key] = ggml_decoder;
360357
361358 std::vector<std::string> ov_input_names;
362359 std::vector<std::string> ov_output_names;
@@ -366,13 +363,13 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
366363 for (const auto & ov_output : model->get_results ()) {
367364 ov_output_names.push_back (ov_output->get_friendly_name ());
368365 }
369- ov_input_names_cache[key] = std::move (ov_input_names);
370- ov_output_names_cache[key] = std::move (ov_output_names);
366+ r_ctx-> ov_input_names_cache [key] = std::move (ov_input_names);
367+ r_ctx-> ov_output_names_cache [key] = std::move (ov_output_names);
371368 }
372369 }
373370
374- auto ov_input_names = ov_input_names_cache[key];
375- auto ov_output_names = ov_output_names_cache[key];
371+ auto ov_input_names = r_ctx-> ov_input_names_cache [key];
372+ auto ov_output_names = r_ctx-> ov_output_names_cache [key];
376373
377374 if (is_prefill) {
378375 auto inp_len = inp_pos->ne [0 ];
0 commit comments