Skip to content

Commit 1df9b7c

Browse files
add wasm load by name for WASI-NN
1 parent a996689 commit 1df9b7c

4 files changed

Lines changed: 11 additions & 7 deletions

File tree

core/iwasm/libraries/wasi-nn/include/wasi_nn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
3030
__attribute__((import_module("wasi_nn")));
3131

3232
wasi_nn_error
33-
load_by_name(const char *name, graph *g)
33+
load_by_name(char *name, uint32_t name_len, graph *g)
3434
__attribute__((import_module("wasi_nn")));
3535

3636
/**

core/iwasm/libraries/wasi-nn/src/wasi_nn.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
697697
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
698698
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
699699
REG_NATIVE_FUNC(load, "(*ii*)i"),
700+
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
700701
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
701702
REG_NATIVE_FUNC(set_input, "(ii*)i"),
702703
REG_NATIVE_FUNC(compute, "(i)i"),

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
8585
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
8686
return runtime_error;
8787
}
88-
if (tfl_ctx->models[g].model_pointer == NULL) {
88+
if (tfl_ctx->models[g].model_pointer == NULL
89+
&& tfl_ctx->models[g].model == NULL) {
8990
NN_ERR_PRINTF("Context (model) non-initialized.");
9091
return runtime_error;
9192
}
92-
if (tfl_ctx->models[g].model == NULL) {
93-
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
94-
return runtime_error;
95-
}
9693
return success;
9794
}
9895

core/iwasm/libraries/wasi-nn/test/utils.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <stdio.h>
1111
#include <stdlib.h>
12+
#define USE_WASM_LOAD_BY_NAME 1
1213

1314
wasi_nn_error
1415
wasm_load(char *model_name, graph *g, execution_target target)
@@ -58,7 +59,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
5859
wasi_nn_error
5960
wasm_load_by_name(const char *model_name, graph *g)
6061
{
61-
wasi_nn_error res = load_by_name(model_name, g);
62+
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
6263
return res;
6364
}
6465

@@ -108,7 +109,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
108109
uint32_t num_output_tensors)
109110
{
110111
graph graph;
112+
113+
#if USE_WASM_LOAD_BY_NAME == 0
111114
if (wasm_load(model_name, &graph, target) != success) {
115+
#else
116+
if (wasm_load_by_name(model_name, &graph) != success) {
117+
#endif
112118
NN_ERR_PRINTF("Error when loading model.");
113119
exit(1);
114120
}

0 commit comments

Comments
 (0)