55
66#include "utils.h"
77#include "logger.h"
8+ #include "wasi_nn.h"
9+
810#include <stdio.h>
911#include <stdlib.h>
1012
11- wasi_ephemeral_nn_error
12- wasm_load (char * model_name , wasi_ephemeral_nn_graph * g ,
13- wasi_ephemeral_nn_execution_target target )
13+ wasi_nn_error
14+ wasm_load (char * model_name , graph * g , execution_target target )
1415{
1516 FILE * pFile = fopen (model_name , "r" );
1617 if (pFile == NULL )
17- return wasi_ephemeral_nn_error_invalid_argument ;
18+ return invalid_argument ;
1819
1920 uint8_t * buffer ;
2021 size_t result ;
@@ -23,122 +24,108 @@ wasm_load(char *model_name, wasi_ephemeral_nn_graph *g,
2324 buffer = (uint8_t * )malloc (sizeof (uint8_t ) * MAX_MODEL_SIZE );
2425 if (buffer == NULL ) {
2526 fclose (pFile );
26- return wasi_ephemeral_nn_error_too_large ;
27+ return too_large ;
2728 }
2829
2930 result = fread (buffer , 1 , MAX_MODEL_SIZE , pFile );
3031 if (result <= 0 ) {
3132 fclose (pFile );
3233 free (buffer );
33- return wasi_ephemeral_nn_error_too_large ;
34+ return too_large ;
3435 }
3536
36- wasi_ephemeral_nn_graph_builder arr ;
37+ graph_builder_array arr ;
38+
39+ arr .size = 1 ;
40+ arr .buf = (graph_builder * )malloc (sizeof (graph_builder ));
41+ if (arr .buf == NULL ) {
42+ fclose (pFile );
43+ free (buffer );
44+ return too_large ;
45+ }
3746
38- arr .buf = buffer ;
39- arr .size = result ;
47+ arr .buf [ 0 ]. size = result ;
48+ arr .buf [ 0 ]. buf = buffer ;
4049
41- wasi_ephemeral_nn_error res = wasi_ephemeral_nn_load (
42- & arr , result , wasi_ephemeral_nn_encoding_tensorflowlite , target , g );
50+ wasi_nn_error res = load (& arr , tensorflowlite , target , g );
4351
4452 fclose (pFile );
4553 free (buffer );
4654 free (arr .buf );
4755 return res ;
4856}
4957
50- wasi_ephemeral_nn_error
51- wasm_load_by_name (const char * model_name , wasi_ephemeral_nn_graph * g )
58+ wasi_nn_error
59+ wasm_load_by_name (const char * model_name , graph * g )
5260{
53- wasi_ephemeral_nn_error res =
54- wasi_ephemeral_nn_load_by_name (model_name , strlen (model_name ), g );
61+ wasi_nn_error res = load_by_name (model_name , strlen (model_name ), g );
5562 return res ;
5663}
5764
58- wasi_ephemeral_nn_error
59- wasm_init_execution_context (wasi_ephemeral_nn_graph g ,
60- wasi_ephemeral_nn_graph_execution_context * ctx )
65+ wasi_nn_error
66+ wasm_init_execution_context (graph g , graph_execution_context * ctx )
6167{
62- return wasi_ephemeral_nn_init_execution_context (g , ctx );
68+ return init_execution_context (g , ctx );
6369}
6470
65- wasi_ephemeral_nn_error
66- wasm_set_input (wasi_ephemeral_nn_graph_execution_context ctx ,
67- float * input_tensor , uint32_t * dim )
71+ wasi_nn_error
72+ wasm_set_input (graph_execution_context ctx , float * input_tensor , uint32_t * dim )
6873{
69- wasi_ephemeral_nn_tensor_dimensions dims ;
74+ tensor_dimensions dims ;
7075 dims .size = INPUT_TENSOR_DIMS ;
7176 dims .buf = (uint32_t * )malloc (dims .size * sizeof (uint32_t ));
7277 if (dims .buf == NULL )
73- return wasi_ephemeral_nn_error_too_large ;
78+ return too_large ;
7479
75- wasi_ephemeral_nn_tensor tensor ;
76- tensor .dimensions = dims ;
77- for (int i = 0 ; i < tensor .dimensions .size ; ++ i )
78- tensor .dimensions .buf [i ] = dim [i ];
79- tensor .type = wasi_ephemeral_nn_type_fp32 ;
80- tensor .data .buf = (uint8_t * )input_tensor ;
81-
82- uint32_t tmp_size = 1 ;
83- if (dim )
84- for (int i = 0 ; i < INPUT_TENSOR_DIMS ; ++ i )
85- tmp_size *= dim [i ];
86-
87- tensor .data .size = (tmp_size * sizeof (float ));
88-
89- wasi_ephemeral_nn_error err = wasi_ephemeral_nn_set_input (ctx , 0 , & tensor );
80+ tensor tensor ;
81+ tensor .dimensions = & dims ;
82+ for (int i = 0 ; i < tensor .dimensions -> size ; ++ i )
83+ tensor .dimensions -> buf [i ] = dim [i ];
84+ tensor .type = fp32 ;
85+ tensor .data = (uint8_t * )input_tensor ;
86+ wasi_nn_error err = set_input (ctx , 0 , & tensor );
9087
9188 free (dims .buf );
9289 return err ;
9390}
9491
95- wasi_ephemeral_nn_error
96- wasm_compute (wasi_ephemeral_nn_graph_execution_context ctx )
92+ wasi_nn_error
93+ wasm_compute (graph_execution_context ctx )
9794{
98- return wasi_ephemeral_nn_compute (ctx );
95+ return compute (ctx );
9996}
10097
101- wasi_ephemeral_nn_error
102- wasm_get_output (wasi_ephemeral_nn_graph_execution_context ctx , uint32_t index ,
103- float * out_tensor , uint32_t * out_size )
98+ wasi_nn_error
99+ wasm_get_output (graph_execution_context ctx , uint32_t index , float * out_tensor ,
100+ uint32_t * out_size )
104101{
105- return wasi_ephemeral_nn_get_output (ctx , index , (uint8_t * )out_tensor ,
106- MAX_OUTPUT_TENSOR_SIZE , out_size );
102+ return get_output (ctx , index , (uint8_t * )out_tensor , out_size );
107103}
108104
109105float *
110- run_inference (float * input , uint32_t * input_size , uint32_t * output_size ,
111- char * model_name , uint32_t num_output_tensors )
106+ run_inference (float * input , uint32_t * input_size ,
107+ uint32_t * output_size , char * model_name ,
108+ uint32_t num_output_tensors )
112109{
113- wasi_ephemeral_nn_graph graph ;
110+ graph graph ;
114111
115- wasi_ephemeral_nn_error res = wasm_load_by_name (model_name , & graph );
116-
117- if (res == wasi_ephemeral_nn_error_not_found ) {
118- NN_INFO_PRINTF ("Model %s is not loaded, you should pass its path "
119- "through --wasi-nn-graph" ,
120- model_name );
121- return NULL ;
122- }
123- else if (res != wasi_ephemeral_nn_error_success ) {
112+ if (wasm_load_by_name (model_name , & graph ) != success ) {
124113 NN_ERR_PRINTF ("Error when loading model." );
125114 exit (1 );
126115 }
127116
128- wasi_ephemeral_nn_graph_execution_context ctx ;
129- if (wasm_init_execution_context (graph , & ctx )
130- != wasi_ephemeral_nn_error_success ) {
117+ graph_execution_context ctx ;
118+ if (wasm_init_execution_context (graph , & ctx ) != success ) {
131119 NN_ERR_PRINTF ("Error when initialixing execution context." );
132120 exit (1 );
133121 }
134122
135- if (wasm_set_input (ctx , input , input_size )
136- != wasi_ephemeral_nn_error_success ) {
123+ if (wasm_set_input (ctx , input , input_size ) != success ) {
137124 NN_ERR_PRINTF ("Error when setting input tensor." );
138125 exit (1 );
139126 }
140127
141- if (wasm_compute (ctx ) != wasi_ephemeral_nn_error_success ) {
128+ if (wasm_compute (ctx ) != success ) {
142129 NN_ERR_PRINTF ("Error when running inference." );
143130 exit (1 );
144131 }
@@ -153,7 +140,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size,
153140 for (int i = 0 ; i < num_output_tensors ; ++ i ) {
154141 * output_size = MAX_OUTPUT_TENSOR_SIZE - * output_size ;
155142 if (wasm_get_output (ctx , i , & out_tensor [offset ], output_size )
156- != wasi_ephemeral_nn_error_success ) {
143+ != success ) {
157144 NN_ERR_PRINTF ("Error when getting index %d." , i );
158145 break ;
159146 }
0 commit comments