diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index ad1f37deb5..2c7d9ea237 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -15,6 +15,44 @@ #include <stdint.h> #include "wasi_nn_types.h" +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +wasi_nn_error +load(graph_builder *builder, uint32_t builder_wasm_size, + graph_encoding encoding, execution_target target, graph *g) + __attribute__((import_module("wasi_ephemeral_nn"))); + +wasi_nn_error +load_by_name(char *name, uint32_t name_len, graph *g) + __attribute__((import_module("wasi_ephemeral_nn"))); + +wasi_nn_error +load_by_name_with_config(const char *name, uint32_t name_len, void *config, + uint32_t config_len, graph *g) + __attribute__((import_module("wasi_ephemeral_nn"))); +/** + * INFERENCE + * + */ + +wasi_nn_error +init_execution_context(graph g, graph_execution_context *exec_ctx) + __attribute__((import_module("wasi_ephemeral_nn"))); + +wasi_nn_error +set_input(graph_execution_context ctx, uint32_t index, tensor *tensor) + __attribute__((import_module("wasi_ephemeral_nn"))); + +wasi_nn_error +compute(graph_execution_context ctx) + __attribute__((import_module("wasi_ephemeral_nn"))); + +wasi_nn_error +get_output(graph_execution_context ctx, uint32_t index, + tensor_data output_tensor, uint32_t output_tensor_len, + uint32_t *output_tensor_size) + __attribute__((import_module("wasi_ephemeral_nn"))); +#else + /** * @brief Load an opaque sequence of bytes to use for inference. 
* @@ -30,7 +68,7 @@ load(graph_builder_array *builder, graph_encoding encoding, __attribute__((import_module("wasi_nn"))); wasi_nn_error -load_by_name(const char *name, graph *g) +load_by_name(char *name, uint32_t name_len, graph *g) __attribute__((import_module("wasi_nn"))); /** @@ -86,5 +124,5 @@ wasi_nn_error get_output(graph_execution_context ctx, uint32_t index, tensor_data output_tensor, uint32_t *output_tensor_size) __attribute__((import_module("wasi_nn"))); - +#endif #endif diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 4697e931b0..75f362c76f 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = { REG_NATIVE_FUNC(get_output, "(ii*i*)i"), #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ REG_NATIVE_FUNC(load, "(*ii*)i"), + REG_NATIVE_FUNC(load_by_name, "(*i*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(compute, "(i)i"), diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index f63d57e074..517f374108 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g) NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST); return runtime_error; } - if (tfl_ctx->models[g].model_pointer == NULL) { + if (tfl_ctx->models[g].model_pointer == NULL + && tfl_ctx->models[g].model == NULL) { NN_ERR_PRINTF("Context (model) non-initialized."); return runtime_error; } - if (tfl_ctx->models[g].model == NULL) { - NN_ERR_PRINTF("Context (tflite model) non-initialized."); - return runtime_error; - } return success; } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c 
index 9e43ec9854..290dccb8ed 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target) wasi_nn_error wasm_load_by_name(const char *model_name, graph *g) { - wasi_nn_error res = load_by_name(model_name, g); + wasi_nn_error res = load_by_name(model_name, strlen(model_name), g); return res; } @@ -99,7 +99,11 @@ wasi_nn_error wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, uint32_t *out_size) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + return get_output(ctx, index, (uint8_t *)out_tensor, *out_size, out_size); +#else return get_output(ctx, index, (uint8_t *)out_tensor, out_size); +#endif } float * @@ -108,7 +112,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size, uint32_t num_output_tensors) { graph graph; + +#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 if (wasm_load(model_name, &graph, target) != success) { +#else + if (wasm_load_by_name(model_name, &graph) != success) { +#endif NN_ERR_PRINTF("Error when loading model."); exit(1); }