Changes from 3 commits
2 changes: 1 addition & 1 deletion core/iwasm/libraries/wasi-nn/include/wasi_nn.h
@@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn")));

wasi_nn_error
load_by_name(const char *name, graph *g)
load_by_name(char *name, uint32_t name_len, graph *g)
Contributor:
is this a bug fix?

Contributor:
wasi_nn.h is a header for WebAssembly applications written in the C language. Is there a specific reason that we need to change it?

Contributor Author:
Yes, it's a bug fix.

Contributor:
Since we have two sets of APIs for historical reasons, we might remove one in another PR. For now, let's ensure both are functional.

  • I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side.
  • With this flag, we declare two sets of APIs in wasi_nn.h. Please align them with the content of native_symbols_wasi_nn in wasi_nn.c.

Contributor:
@lum1n0us
Where can I find the function prototypes for wasi_ephemeral_nn?
For functions like get_output, the signatures are different, so simply replacing wasi_nn with wasi_ephemeral_nn mechanically doesn't seem to work.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
REG_NATIVE_FUNC(load, "(*iii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(load_by_name_with_config, "(*i*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii**)i"),
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Even when referring to the wasi-nn specification, the signatures declared there don't appear to match what is used for wasi_ephemeral_nn:
https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L59-L86

Contributor:
@HongxiaWangSSSS I suggest we use WASM_ENABLE_WASI_EPHEMERAL_NN on the wasm side and there will be two sets for wasi_ephemeral_nn and wasi_nn. wasi_nn will be deprecated in the future.

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name(const char *name, uint32_t len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
load_by_name_with_config(const char *name, uint32_t name_len, void *config,
                         uint32_t config_len, graph *g)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *exec_ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
compute(graph_execution_context ctx)
    __attribute__((import_module("wasi_ephemeral_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

#else

wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
    __attribute__((import_module("wasi_nn")));

wasi_nn_error
compute(graph_execution_context ctx) __attribute__((import_module("wasi_nn")));

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_nn")));
#endif

Contributor:
Is there a mismatch in the function signature?

REG_NATIVE_FUNC(get_output, "(ii*i*)i"),

and

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

Contributor (@lum1n0us, May 27, 2025):
Yes. There should be two versions: one for wasi_ephemeral_nn, another for wasi_nn. Please refer to:

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
             uint32_t builder_wasm_size, graph_encoding encoding,
             execution_target target, graph *g)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
             graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */


#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else  /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
                   uint32_t index, tensor_data output_tensor,
                   uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

Contributor:
If so, we need to define it in wasi_nn.h:

wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t output_tensor_len,
           uint32_t *output_tensor_size)
    __attribute__((import_module("wasi_ephemeral_nn")));

But the backend definition doesn't look like it matches (signature might be ok)

__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size)

Contributor (@lum1n0us, May 27, 2025):
Yes. I am suggesting this #4267 (comment) as the new content of wasi_nn.h, plus #4267 (comment).

__attribute__((import_module("wasi_nn")));

/**
7 changes: 2 additions & 5 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp
@@ -85,14 +85,11 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error;
}
if (tfl_ctx->models[g].model_pointer == NULL) {
if (tfl_ctx->models[g].model_pointer == NULL
&& tfl_ctx->models[g].model == NULL) {
Contributor:
The original version can output different information based on various invalid argument cases. Is there a specific reason we need to merge them?

Contributor:
It is required to validate TFLiteContext.models[g] for both cases, using load() and load_by_name(). It will not be acceptable if the change disables one of these cases.

Contributor Author:
Whether it is load or load_by_name, the check of models[g].model_pointer does not seem to be necessary; just making sure models[g].model is not NULL may be enough.
Do you have any idea?

Contributor:
does not seem to be necessary.

Why is that?

Contributor Author:
After this operation, the content referenced by models[g].model_pointer has been saved into models[g].model.

Contributor:
https://github.com/tensorflow/tensorflow/blob/02896e880298894beedc71f8666f6949ad0e174f/tensorflow/compiler/mlir/lite/core/model_builder_base.h#L169-#L184

  /// Builds a model based on a pre-loaded flatbuffer.
  /// Caller retains ownership of the buffer and should keep it alive until
  /// the returned object is destroyed. Caller also retains ownership of
  /// `error_reporter` and must ensure its lifetime is longer than the
  /// FlatBufferModelBase instance.
  /// Returns a nullptr in case of failure.
  /// NOTE: this does NOT validate the buffer so it should NOT be called on
  /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case
  static std::unique_ptr<T> BuildFromBuffer(
      const char* caller_owned_buffer, size_t buffer_size,
      ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) {
    error_reporter = ValidateErrorReporter(error_reporter);
    std::unique_ptr<Allocation> allocation(
        new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter));
    return BuildFromAllocation(std::move(allocation), error_reporter);
  }

If I understand correctly, model_pointer acts as a pre-allocated buffer, and its ownership is still held by the caller, in our case, tfl_ctx. Meanwhile, model holds the ownership of the result from tflite::FlatBufferModel::BuildFromBuffer(). Therefore, both are required.

NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) {
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
return runtime_error;
}
return success;
}

7 changes: 6 additions & 1 deletion core/iwasm/libraries/wasi-nn/test/utils.c
@@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
wasi_nn_error res = load_by_name(model_name, g);
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
Contributor:
better be

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
{
    wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
    return res;
}
#endif 

return res;
}

@@ -108,7 +108,12 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors)
{
graph graph;

#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
if (wasm_load(model_name, &graph, target) != success) {
#else
if (wasm_load_by_name(model_name, &graph) != success) {
#endif
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}