@@ -174,7 +174,7 @@ convert_wasi_nn_type_to_ort_type(tensor_type type,
174174#endif
175175 default :
176176 NN_WARN_PRINTF (" Unsupported wasi-nn tensor type: %d" , type);
177- return false ; // Default to float
177+ return false ;
178178 }
179179 return true ;
180180}
@@ -418,13 +418,17 @@ __attribute__((visibility("default"))) wasi_nn_error
418418init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
419419{
420420 OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
421+ if (!onnx_ctx) {
422+ return runtime_error;
423+ }
424+
425+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
421426
422427 if (g >= MAX_GRAPHS || !ort_ctx->graphs [g].is_initialized ) {
423428 NN_ERR_PRINTF (" Invalid graph handle: %d" , g);
424429 return invalid_argument;
425430 }
426431
427- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
428432 int ctx_index = -1 ;
429433 for (int i = 0 ; i < MAX_CONTEXTS; i++) {
430434 if (!ort_ctx->exec_ctxs [i].is_initialized ) {
@@ -516,6 +520,11 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
516520 tensor *input_tensor)
517521{
518522 OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
523+ if (!onnx_ctx) {
524+ return runtime_error;
525+ }
526+
527+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
519528
520529 if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
521530 NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
@@ -528,7 +537,6 @@ set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
528537 return invalid_argument;
529538 }
530539
531- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
532540 OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
533541
534542 OrtTypeInfo *type_info = nullptr ;
@@ -605,13 +613,17 @@ __attribute__((visibility("default"))) wasi_nn_error
605613compute(void *onnx_ctx, graph_execution_context ctx)
606614{
607615 OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
616+ if (!onnx_ctx) {
617+ return runtime_error;
618+ }
619+
620+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
608621
609622 if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
610623 NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
611624 return invalid_argument;
612625 }
613626
614- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
615627 OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
616628
617629 std::vector<OrtValue *> input_values;
@@ -657,6 +669,11 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
657669 tensor_data *out_buffer, uint32_t *out_buffer_size)
658670{
659671 OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
672+ if (!onnx_ctx) {
673+ return runtime_error;
674+ }
675+
676+ std::lock_guard<std::mutex> lock (ort_ctx->mutex );
660677
661678 if (ctx >= MAX_CONTEXTS || !ort_ctx->exec_ctxs [ctx].is_initialized ) {
662679 NN_ERR_PRINTF (" Invalid execution context handle: %d" , ctx);
@@ -669,7 +686,6 @@ get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
669686 return invalid_argument;
670687 }
671688
672- std::lock_guard<std::mutex> lock (ort_ctx->mutex );
673689 OnnxRuntimeExecCtx *exec_ctx = &ort_ctx->exec_ctxs [ctx];
674690
675691 OrtValue *output_value = exec_ctx->outputs [index];
0 commit comments