@@ -102,91 +102,81 @@ convert_ort_error_to_wasi_nn_error(OrtStatus *status)
102102 return err;
103103}
104104
105- static tensor_type
106- convert_ort_type_to_wasi_nn_type (ONNXTensorElementDataType ort_type)
105+ static bool
106+ convert_ort_type_to_wasi_nn_type (ONNXTensorElementDataType ort_type, tensor_type *tensor_type )
107107{
108108 switch (ort_type) {
109109 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
110- return fp32;
110+ *tensor_type = fp32;
111+ break ;
111112 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
112- return fp16;
113+ *tensor_type = fp16;
114+ break ;
113115#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
114116 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
115- return fp64;
117+ *tensor_type = fp64;
118+ break ;
116119 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
117- return u8 ;
120+ *tensor_type = u8 ;
121+ break ;
118122 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
119- return i32 ;
123+ *tensor_type = i32 ;
124+ break ;
120125 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
121- return i64 ;
126+ *tensor_type = i64 ;
127+ break ;
122128#else
123129 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
124- return up8;
130+ *tensor_type = up8;
131+ break ;
125132 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
126- return ip32;
133+ *tensor_type = ip32;
134+ break ;
127135#endif
128136 default :
129137 NN_WARN_PRINTF (" Unsupported ONNX tensor type: %d" , ort_type);
130- return fp32; // Default to fp32
138+ return false ;
131139 }
132- }
133140
134- static ONNXTensorElementDataType
135- convert_wasi_nn_type_to_ort_type (tensor_type type)
136- {
137- switch (type) {
138- case fp32:
139- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
140- case fp16:
141- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
142- #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
143- case fp64:
144- return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
145- case u8 :
146- return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
147- case i32 :
148- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
149- case i64 :
150- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
151- #else
152- case up8:
153- return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
154- case ip32:
155- return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
156- #endif
157- default :
158- NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
159- return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float
160- }
141+ return true ;
161142}
162143
163- static size_t
164- get_tensor_element_size (tensor_type type)
144+ static bool
145+ convert_wasi_nn_type_to_ort_type (tensor_type type, ONNXTensorElementDataType *ort_type )
165146{
166147 switch (type) {
167148 case fp32:
168- return 4 ;
149+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
150+ break ;
169151 case fp16:
170- return 2 ;
152+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
153+ break ;
171154#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
172155 case fp64:
173- return 8 ;
156+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
157+ break ;
174158 case u8 :
175- return 1 ;
159+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
160+ break ;
176161 case i32 :
177- return 4 ;
162+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
163+ break ;
178164 case i64 :
179- return 8 ;
165+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
166+ break ;
180167#else
181168 case up8:
182- return 1 ;
169+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
170+ break ;
183171 case ip32:
184- return 4 ;
172+ *ort_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
173+ break ;
185174#endif
186175 default :
187- NN_WARN_PRINTF (" Unsupported tensor type: %d" , type);
188- return 4 ; // Default to 4 bytes ( float)
176+ NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
177+ return false ; // Default to float
189178 }
179+ return true ;
190180}
191181
192182/* Backend API implementation */
@@ -579,8 +569,12 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
579569 ort_dims[i] = input_tensor->dimensions ->buf [i];
580570 }
581571
582- ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type (
583- static_cast <tensor_type>(input_tensor->type ));
572+ ONNXTensorElementDataType ort_type;
573+ if (!convert_wasi_nn_type_to_ort_type (
574+ static_cast <tensor_type>(input_tensor->type ), &ort_type)) {
575+ NN_ERR_PRINTF (" Failed to convert tensor type" );
576+ return runtime_error;
577+ }
584578
585579 OrtValue *input_value = nullptr ;
586580 size_t total_elements = 1 ;
@@ -589,9 +583,7 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
589583 }
590584
591585 status = ort_ctx->ort_api ->CreateTensorWithDataAsOrtValue (
592- exec_ctx->memory_info , input_tensor->data .buf ,
593- get_tensor_element_size (static_cast <tensor_type>(input_tensor->type ))
594- * total_elements,
586+ exec_ctx->memory_info , input_tensor->data .buf ,input_tensor->data .size ,
595587 ort_dims, num_dims, ort_type, &input_value);
596588
597589 free (ort_dims);
@@ -793,18 +785,16 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
793785 }
794786
795787 size_t output_size_bytes = tensor_size * element_size;
796-
797- NN_INFO_PRINTF (" Output tensor size: %zu elements, element size: %zu bytes, "
798- " total: %zu bytes" ,
799- tensor_size, element_size, output_size_bytes);
800-
801- if (*out_buffer_size < output_size_bytes) {
788+ if (out_buffer->size < output_size_bytes) {
802789 NN_ERR_PRINTF (
803790 " Output buffer too small: %u bytes provided, %zu bytes needed" ,
804- *out_buffer_size , output_size_bytes);
791+ out_buffer-> size , output_size_bytes);
805792 *out_buffer_size = output_size_bytes;
806- return invalid_argument ;
793+ return too_large ;
807794 }
795+ NN_INFO_PRINTF (" Output tensor size: %zu elements, element size: %zu bytes, "
796+ " total: %zu bytes" ,
797+ tensor_size, element_size, output_size_bytes);
808798
809799 if (tensor_data == nullptr ) {
810800 NN_ERR_PRINTF (" Tensor data is null" );
0 commit comments