Skip to content

Commit aa1ff77

Browse files
add load_by_name in wasi-nn (#4298)
1 parent 2a30386 commit aa1ff77

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

core/iwasm/libraries/wasi-nn/include/wasi_nn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
3030
__attribute__((import_module("wasi_nn")));
3131

3232
wasi_nn_error
33-
load_by_name(const char *name, graph *g)
33+
load_by_name(const char *name, uint32_t name_len, graph *g)
3434
__attribute__((import_module("wasi_nn")));
3535

3636
/**

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
697697
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
698698
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
699699
REG_NATIVE_FUNC(load, "(*ii*)i"),
700+
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
700701
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
701702
REG_NATIVE_FUNC(set_input, "(ii*)i"),
702703
REG_NATIVE_FUNC(compute, "(i)i"),

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
8585
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
8686
return runtime_error;
8787
}
88-
if (tfl_ctx->models[g].model_pointer == NULL) {
89-
NN_ERR_PRINTF("Context (model) non-initialized.");
90-
return runtime_error;
91-
}
9288
if (tfl_ctx->models[g].model == NULL) {
93-
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
89+
NN_ERR_PRINTF("Context (model) non-initialized.");
9490
return runtime_error;
9591
}
9692
return success;
@@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx)
472468
NN_DBG_PRINTF("Freeing memory.");
473469
for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
474470
tfl_ctx->models[i].model.reset();
475-
if (tfl_ctx->models[i].model_pointer) {
476-
if (tfl_ctx->delegate) {
477-
switch (tfl_ctx->models[i].target) {
478-
case gpu:
479-
{
471+
if (tfl_ctx->delegate) {
472+
switch (tfl_ctx->models[i].target) {
473+
case gpu:
474+
{
480475
#if WASM_ENABLE_WASI_NN_GPU != 0
481-
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
476+
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
482477
#else
483-
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
478+
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
484479
#endif
485-
break;
486-
}
487-
case tpu:
488-
{
480+
break;
481+
}
482+
case tpu:
483+
{
489484
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
490-
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
485+
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
491486
#else
492-
NN_ERR_PRINTF(
493-
"External delegate delete but not enabled.");
487+
NN_ERR_PRINTF("External delegate delete but not enabled.");
494488
#endif
495-
break;
496-
}
497-
default:
498-
break;
489+
break;
499490
}
491+
default:
492+
break;
500493
}
494+
}
495+
if (tfl_ctx->models[i].model_pointer) {
501496
wasm_runtime_free(tfl_ctx->models[i].model_pointer);
502497
}
503498
tfl_ctx->models[i].model_pointer = NULL;

core/iwasm/libraries/wasi-nn/test/utils.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
5858
wasi_nn_error
5959
wasm_load_by_name(const char *model_name, graph *g)
6060
{
61-
wasi_nn_error res = load_by_name(model_name, g);
61+
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
6262
return res;
6363
}
6464

@@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
108108
uint32_t num_output_tensors)
109109
{
110110
graph graph;
111-
if (wasm_load(model_name, &graph, target) != success) {
111+
112+
if (wasm_load_by_name(model_name, &graph) != success) {
112113
NN_ERR_PRINTF("Error when loading model.");
113114
exit(1);
114115
}

0 commit comments

Comments
 (0)