@@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
397397 return ret ;
398398}
399399
400+ static wasi_nn_error
401+ ensure_backend (wasm_module_inst_t instance , graph_encoding encoding ,
402+ WASINNContext * * wasi_nn_ctx_ptr )
403+ {
404+ wasi_nn_error res ;
405+
406+ graph_encoding loaded_backend = autodetect ;
407+ if (!detect_and_load_backend (encoding , & loaded_backend )) {
408+ res = invalid_encoding ;
409+ NN_ERR_PRINTF ("load backend failed" );
410+ goto fail ;
411+ }
412+
413+ WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
414+ if (wasi_nn_ctx -> is_backend_ctx_initialized ) {
415+ if (wasi_nn_ctx -> backend != loaded_backend ) {
416+ res = unsupported_operation ;
417+ goto fail ;
418+ }
419+ }
420+ else {
421+ wasi_nn_ctx -> backend = loaded_backend ;
422+
423+ /* init() the backend */
424+ call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
425+ & wasi_nn_ctx -> backend_ctx );
426+ if (res != success )
427+ goto fail ;
428+
429+ wasi_nn_ctx -> is_backend_ctx_initialized = true;
430+ }
431+ * wasi_nn_ctx_ptr = wasi_nn_ctx ;
432+ return success ;
433+ fail :
434+ return res ;
435+ }
436+
400437/* WASI-NN implementation */
401438
402439#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -410,14 +447,15 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
410447 graph_encoding encoding , execution_target target , graph * g )
411448#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
412449{
450+ wasi_nn_error res ;
451+
413452 NN_DBG_PRINTF ("[WASI NN] LOAD [encoding=%d, target=%d]..." , encoding ,
414453 target );
415454
416455 wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
417456 if (!instance )
418457 return runtime_error ;
419458
420- wasi_nn_error res ;
421459 graph_builder_array builder_native = { 0 };
422460#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
423461 if (success
@@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
438476 goto fail ;
439477 }
440478
441- graph_encoding loaded_backend = autodetect ;
442- if (!detect_and_load_backend (encoding , & loaded_backend )) {
443- res = invalid_encoding ;
444- NN_ERR_PRINTF ("load backend failed" );
445- goto fail ;
446- }
447-
448- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
449- wasi_nn_ctx -> backend = loaded_backend ;
450-
451- /* init() the backend */
452- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
453- & wasi_nn_ctx -> backend_ctx );
479+ WASINNContext * wasi_nn_ctx ;
480+ res = ensure_backend (instance , encoding , & wasi_nn_ctx );
454481 if (res != success )
455482 goto fail ;
456483
@@ -473,6 +500,8 @@ wasi_nn_error
473500wasi_nn_load_by_name (wasm_exec_env_t exec_env , char * name , uint32_t name_len ,
474501 graph * g )
475502{
503+ wasi_nn_error res ;
504+
476505 wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
477506 if (!instance ) {
478507 return runtime_error ;
@@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
496525
497526 NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME %s..." , name );
498527
499- graph_encoding loaded_backend = autodetect ;
500- if (!detect_and_load_backend (autodetect , & loaded_backend )) {
501- NN_ERR_PRINTF ("load backend failed" );
502- return invalid_encoding ;
503- }
504-
505- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
506- wasi_nn_ctx -> backend = loaded_backend ;
507-
508- wasi_nn_error res ;
509- /* init() the backend */
510- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
511- & wasi_nn_ctx -> backend_ctx );
528+ WASINNContext * wasi_nn_ctx ;
529+ res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
512530 if (res != success )
513531 return res ;
514532
@@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
526544 int32_t name_len , char * config ,
527545 int32_t config_len , graph * g )
528546{
547+ wasi_nn_error res ;
548+
529549 wasm_module_inst_t instance = wasm_runtime_get_module_inst (exec_env );
530550 if (!instance ) {
531551 return runtime_error ;
@@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
554574
555575 NN_DBG_PRINTF ("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s..." , name , config );
556576
557- graph_encoding loaded_backend = autodetect ;
558- if (!detect_and_load_backend (autodetect , & loaded_backend )) {
559- NN_ERR_PRINTF ("load backend failed" );
560- return invalid_encoding ;
561- }
562-
563- WASINNContext * wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx (instance );
564- wasi_nn_ctx -> backend = loaded_backend ;
565-
566- wasi_nn_error res ;
567- /* init() the backend */
568- call_wasi_nn_func (wasi_nn_ctx -> backend , init , res ,
569- & wasi_nn_ctx -> backend_ctx );
577+ WASINNContext * wasi_nn_ctx ;
578+ res = ensure_backend (instance , autodetect , & wasi_nn_ctx );
570579 if (res != success )
571580 return res ;
572581
0 commit comments