Closed
42 changes: 40 additions & 2 deletions core/iwasm/libraries/wasi-nn/include/wasi_nn.h
@@ -15,6 +15,44 @@
#include <stdint.h>
#include "wasi_nn_types.h"

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
Contributor Author:

Hi @ayakoakasaka and @lum1n0us ,
For wasi_nn and wasi_ephemeral_nn, different APIs need to be defined in the header file.
However, when using wasi_ephemeral_nn at the moment, I get an error when calling set_input: the content is inconsistent when passed from wasm to native.
So can we split this into two PRs: first support load_by_name in wasi-nn, and then implement support for wasi_ephemeral_nn?

Contributor:

Sure. As a workaround.

wasi_nn_error
load(graph_builder *builder, uint32_t builder_wasm_size,
     graph_encoding encoding, execution_target target, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name(char *name, uint32_t name_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name_with_config(const char *name, uint32_t name_len, void *config,
                         uint32_t config_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

/**
 * INFERENCE
 */

wasi_nn_error
init_execution_context(graph g, graph_execution_context *exec_ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
compute(graph_execution_context ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t output_tensor_len,
           uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));
#else

/**
* @brief Load an opaque sequence of bytes to use for inference.
*
@@ -30,7 +68,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn")));

wasi_nn_error
load_by_name(const char *name, graph *g)
load_by_name(char *name, uint32_t name_len, graph *g)
Contributor:

is this a bug fix?

Contributor:

wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?

Contributor Author:

Yes, it's a bugfix.

Contributor:

Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.

  • I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
  • With this flag, we declare two sets of APIs in wasi_nn.h. Please align with the content of native_symbols_wasi_nn in wasi_nn.c

Contributor:

@lum1n0us
Where can I find the function prototypes for wasi_ephemeral_nn?
For functions like get_output, the signatures are different, so simply replacing wasi_nn with wasi_ephemeral_nn mechanically doesn't seem to work.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
REG_NATIVE_FUNC(load, "(*iii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(load_by_name_with_config, "(*i*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii**)i"),
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Even when referring to the wasi-nn specification, the signatures declared there don’t appear to match what is used for wasi_ephemeral_nn
https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L59-L86
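For readers following the `REG_NATIVE_FUNC` strings quoted above: WAMR's native-symbol registration encodes each parameter as one character between the parentheses (`i`/`I`/`f`/`F` for i32/i64/f32/f64, `*` for an app address the runtime validates and converts to a native pointer, `~` for the byte length of the preceding `*` buffer), with the result type after the closing parenthesis. A minimal decoder, purely illustrative and not code from the WAMR runtime:

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Decode a WAMR-style native signature string such as "(ii*i*)i".
 * Returns the number of parameters, or -1 on a malformed string;
 * optionally reports how many of them are pointer ('*') parameters.
 * Illustration only, not WAMR source code. */
static int
count_params(const char *sig, int *pointer_params)
{
    if (sig == NULL || sig[0] != '(')
        return -1;
    const char *close = strchr(sig, ')');
    if (close == NULL)
        return -1;
    int n = 0, ptrs = 0;
    for (const char *p = sig + 1; p < close; p++) {
        switch (*p) {
            case 'i': case 'I': case 'f': case 'F':
                n++;
                break;
            case '*':
                ptrs++;
                n++;
                break;
            case '~': /* length bound to the previous '*' argument */
                n++;
                break;
            default:
                return -1;
        }
    }
    if (pointer_params)
        *pointer_params = ptrs;
    return n;
}
```

Decoding `"(ii*i*)i"` (the ephemeral get_output) this way yields five parameters, two of them pointers, matching ctx, index, output_tensor, output_tensor_len, output_tensor_size — while the legacy `"(ii**)i"` has only four, which is exactly the mismatch being discussed.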

Contributor:

@HongxiaWangSSSS I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side and there will be two sets for wasi_ephemeral_nn and wasi_nn. wasi_nn will be deprecated in the future.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name(const char *name, uint32_t len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name_with_config(const char *name, uint32_t name_len, void *config,
                         uint32_t config_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *exec_ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
compute(graph_execution_context ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

#else

wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_nn")));
#endif

Contributor:

Is there a mismatch in the function signature?

REG_NATIVE_FUNC(get_output, "(ii*i*)i"),

and

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

Contributor @lum1n0us (May 27, 2025):

Yes. There should be two versions: one for wasi_ephemeral_nn, another for wasi_nn. Please refer to:

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
             uint32_t builder_wasm_size, graph_encoding encoding,
             execution_target target, graph *g)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
             graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */


#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Contributor:

If so, we need to define it at wasi_nn.h

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t output_tensor_len,
           uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

But the backend definition doesn't look like it matches (signature might be ok)

__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)

Contributor @lum1n0us (May 27, 2025):

Yes. I am suggesting this #4267 (comment) as the new content of wasi_nn.h. Plus, #4267 (comment).

__attribute__((import_module("wasi_nn")));

/**
@@ -86,5 +124,5 @@ wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
__attribute__((import_module("wasi_nn")));

#endif
#endif
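To make the call order under discussion concrete, here is a hedged sketch of how an application drives the ephemeral API: load_by_name with an explicit length, then context creation, then compute. The host functions are replaced with trivial stubs so the sketch compiles outside a wasm module; in a real build they would be the `__attribute__((import_module("wasi_ephemeral_nn")))` declarations above, and set_input/get_output are elided for brevity.

```c
#include <stdint.h>
#include <string.h>

/* Stand-in types mirroring wasi_nn_types.h (illustration only). */
typedef uint32_t graph;
typedef uint32_t graph_execution_context;
typedef enum { success = 0, runtime_error = 1 } wasi_nn_error;

/* Stubs standing in for the wasi_ephemeral_nn imports. */
static wasi_nn_error
load_by_name(const char *name, uint32_t name_len, graph *g)
{
    if (name == NULL || name_len == 0)
        return runtime_error;
    *g = 1; /* pretend the host loaded a graph */
    return success;
}

static wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
{
    (void)g;
    *ctx = 1;
    return success;
}

static wasi_nn_error
compute(graph_execution_context ctx)
{
    (void)ctx;
    return success;
}

/* The sequence an application follows: load by name (note the explicit
 * name length, which the legacy wasi_nn load_by_name lacks), create an
 * execution context, then run inference. */
static wasi_nn_error
run_model(const char *model_name)
{
    graph g;
    graph_execution_context ctx;
    if (load_by_name(model_name, (uint32_t)strlen(model_name), &g) != success)
        return runtime_error;
    if (init_execution_context(g, &ctx) != success)
        return runtime_error;
    return compute(ctx);
}
```

The stub bodies are hypothetical; only the call sequence and the extra length parameter reflect the declarations being debated in this thread.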
1 change: 1 addition & 0 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
@@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
Contributor:

Why not use -DWASM_ENABLE_WASI_EPHEMERAL_NN=1?

Contributor Author:

At first, I just built with -DWASM_ENABLE_WASI_NN=1.

Contributor:

If you use -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 during compilation, you will be able to use that set of APIs, including load_by_name(). There is no need to change this line.

Contributor Author:

If -DWAMR_BUILD_WASI_EPHEMERAL_NN=1 must be added, I think there is no need to add this line.
Currently, it is possible to use the default wasi-nn instead of wasi_ephemeral_nn.

Contributor:

Yes. Please do it.

Contributor:

So, does that mean that the ephemeral version is meant to be compatible with Rust (especially WasmEdge), whereas the non-ephemeral one doesn't need to be?
If that's the case, wouldn't it make sense to add load_by_name to the wasi_nn.h header, which is a straightforward C interpretation of the witx specification?

Contributor:

Yes. load_by_name() needs a new signature.

Contributor:

Eventually, load_by_name() and wasm_load_by_name() should only exist when WASM_ENABLE_WASI_EPHEMERAL_NN is set to 1. If you transform wasm with the flag WASM_ENABLE_WASI_EPHEMERAL_NN=1, you will get wasm with import requirements from wasi_ephemeral_nn, which offers better performance.

Contributor:

#4267 (comment)
said this is for the C language, so why not keep it for non-wasi_ephemeral_nn to provide the same performance?

Contributor:

In my mind, wasi_ephemeral_nn is legacy and should be deprecated. If there are C APIs, they should follow the Rust API's design to avoid unnecessary changes for the runtime.

REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
7 changes: 2 additions & 5 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error;
}
if (tfl_ctx->models[g].model_pointer == NULL) {
if (tfl_ctx->models[g].model_pointer == NULL
&& tfl_ctx->models[g].model == NULL) {
Contributor:

The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?

Contributor:

It is required to validate TFLiteContext.models[g] for both cases: using load() and using load_by_name(). It will not be acceptable if the change disables one of these cases.

Contributor Author:

Whether it is load or load_by_name, the check of models[g].model_pointer does not seem to be necessary; just making sure models[g].model is not NULL may be enough.
What do you think?

Contributor:

does not seem to be necessary.

Why is that?

Contributor Author:

After this operation, the content of models[g].model_pointer has been saved in models[g].model.

Contributor:

https://github.com/tensorflow/tensorflow/blob/02896e880298894beedc71f8666f6949ad0e174f/tensorflow/compiler/mlir/lite/core/model_builder_base.h#L169-#L184

  /// Builds a model based on a pre-loaded flatbuffer.
  /// Caller retains ownership of the buffer and should keep it alive until
  /// the returned object is destroyed. Caller also retains ownership of
  /// `error_reporter` and must ensure its lifetime is longer than the
  /// FlatBufferModelBase instance.
  /// Returns a nullptr in case of failure.
  /// NOTE: this does NOT validate the buffer so it should NOT be called on
  /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
  static std::unique_ptr<T> BuildFromBuffer(
      const char* caller_owned_buffer, size_t buffer_size,
      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
    error_reporter = ValidateErrorReporter(error_reporter);
    std::unique_ptr<Allocation> allocation(
        new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
    return BuildFromAllocation(std::move(allocation), error_reporter);
  }

If I understand correctly, model_pointer acts as a pre-allocated buffer, and its ownership is still held by the caller, in our case, tfl_ctx. Meanwhile, model holds the ownership of the result from tflite::FlatBufferModel::BuildFromBuffer(). Therefore, both are required.
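The ownership split described above can be illustrated with a small sketch (plain C, not TFLite code; the names are hypothetical): the caller keeps the raw buffer alive, while the parsed handle merely references it, so a usable graph slot needs both fields to be non-NULL.

```c
#include <stdlib.h>
#include <string.h>

/* Hypothetical model slot mirroring the two fields discussed above:
 * model_pointer = caller-owned raw flatbuffer bytes,
 * model         = parsed handle built from those bytes. */
typedef struct {
    char *model_pointer; /* raw bytes; ownership stays with the caller   */
    char *model;         /* parsed view; here it simply aliases the bytes */
} ModelSlot;

/* Stand-in for FlatBufferModel::BuildFromBuffer: the "parsed" handle
 * keeps referencing the caller-owned buffer rather than copying it. */
static char *
build_from_buffer(char *caller_owned_buffer, size_t size)
{
    return size > 0 ? caller_owned_buffer : NULL;
}

static int
is_valid_slot(const ModelSlot *s)
{
    /* Both checks matter: a slot with only raw bytes was never parsed,
     * and a parsed handle without its backing buffer is dangling. */
    return s->model_pointer != NULL && s->model != NULL;
}
```

This is only a sketch of the lifetime argument, not of the actual TFLite backend code.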

NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) {
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
return runtime_error;
}
return success;
}

11 changes: 10 additions & 1 deletion core/iwasm/libraries/wasi-nn/test/utils.c
@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, g);
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
Contributor:

It would be better as:

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
    return res;
}
#endif 

return res;
}

@@ -99,7 +99,11 @@ wasi_nn_error
wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
uint32_t *out_size)
{
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
return get_output(ctx, index, (uint8_t *)out_tensor, *out_size, out_size);
#else
return get_output(ctx, index, (uint8_t *)out_tensor, out_size);
#endif
}

float *
@@ -108,7 +112,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors)
{
graph graph;

#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
if (wasm_load(model_name, &graph, target) != success) {
#else
if (wasm_load_by_name(model_name, &graph) != success) {
#endif
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}