Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/iwasm/libraries/wasi-nn/include/wasi_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn")));

wasi_nn_error
load_by_name(const char *name, graph *g)
load_by_name(const char *name, uint32_t name_len, graph *g)
__attribute__((import_module("wasi_nn")));

/**
Expand Down
1 change: 1 addition & 0 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
Expand Down
7 changes: 2 additions & 5 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error;
}
if (tfl_ctx->models[g].model_pointer == NULL) {
if (tfl_ctx->models[g].model_pointer == NULL
&& tfl_ctx->models[g].model == NULL) {
Comment thread
lum1n0us marked this conversation as resolved.
Outdated
NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) {
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
return runtime_error;
}
return success;
}

Expand Down
7 changes: 6 additions & 1 deletion core/iwasm/libraries/wasi-nn/test/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, g);
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
return res;
}

Expand Down Expand Up @@ -108,7 +108,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors)
{
graph graph;

#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
if (wasm_load(model_name, &graph, target) != success) {
#else
if (wasm_load_by_name(model_name, &graph) != success) {
#endif
Comment thread
lum1n0us marked this conversation as resolved.
Outdated
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}
Expand Down
Loading