Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/iwasm/libraries/wasi-nn/include/wasi_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn")));

wasi_nn_error
load_by_name(const char *name, graph *g)
load_by_name(const char *name, uint32_t name_len, graph *g)
__attribute__((import_module("wasi_nn")));

/**
Expand Down
1 change: 1 addition & 0 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
Expand Down
41 changes: 18 additions & 23 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error;
}
if (tfl_ctx->models[g].model_pointer == NULL) {
NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) {
Comment thread
lum1n0us marked this conversation as resolved.
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
return success;
Expand Down Expand Up @@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx)
NN_DBG_PRINTF("Freeing memory.");
for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
tfl_ctx->models[i].model.reset();
if (tfl_ctx->models[i].model_pointer) {
if (tfl_ctx->delegate) {
switch (tfl_ctx->models[i].target) {
case gpu:
{
if (tfl_ctx->delegate) {
switch (tfl_ctx->models[i].target) {
case gpu:
{
#if WASM_ENABLE_WASI_NN_GPU != 0
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
#else
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
#endif
break;
}
case tpu:
{
break;
}
case tpu:
{
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
#else
NN_ERR_PRINTF(
"External delegate delete but not enabled.");
NN_ERR_PRINTF("External delegate delete but not enabled.");
#endif
break;
}
default:
break;
break;
}
default:
break;
}
}
if (tfl_ctx->models[i].model_pointer) {
wasm_runtime_free(tfl_ctx->models[i].model_pointer);
}
tfl_ctx->models[i].model_pointer = NULL;
Expand Down
5 changes: 3 additions & 2 deletions core/iwasm/libraries/wasi-nn/test/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, g);
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
return res;
}

Expand Down Expand Up @@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors)
{
graph graph;
if (wasm_load(model_name, &graph, target) != success) {

if (wasm_load_by_name(model_name, &graph) != success) {
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}
Expand Down
Loading