@@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
389389 return too_large;
390390 }
391391
392- uint32_t model_tensor_size = 1 ;
393- for (int i = 0 ; i < (int )tensor->dims ->size ; ++i)
394- model_tensor_size *= (uint32_t )tensor->dims ->data [i];
395-
396- if (*output_tensor_size < model_tensor_size) {
397- NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
398- return too_large;
399- }
400-
401392 if (tensor->quantization .type == kTfLiteNoQuantization ) {
402393 NN_DBG_PRINTF (" No quantization information" );
403- float *ot =
404- tfl_ctx->interpreters [ctx].interpreter ->typed_output_tensor <float >(
405- index);
406-
407- int size = model_tensor_size * sizeof (float );
408- bh_memcpy_s (output_tensor, size, ot, size);
394+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
395+ if (*output_tensor_size < tensor->bytes ) {
396+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
397+ return too_large;
398+ }
399+ #else
400+ /*
401+ * for now, maintain the bug-to-bug compatibility with the old abi,
402+ * where the size here is the number of fp32, not bytes.
403+ */
404+ if (*output_tensor_size < tensor->bytes / sizeof (float )) {
405+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
406+ return too_large;
407+ }
408+ #endif
409+ bh_memcpy_s (output_tensor, *output_tensor_size, tensor->data .data ,
410+ tensor->bytes );
411+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
412+ *output_tensor_size = tensor->bytes ;
413+ #else
414+ /*
415+ * for now, maintain the bug-to-bug compatibility with the old abi,
416+ * where the size here is the number of fp32, not bytes.
417+ */
418+ *output_tensor_size = tensor->bytes / sizeof (float );
419+ #endif
409420 }
410421 else { // TODO: Assuming uint8 quantized networks.
411422 TfLiteAffineQuantization *quant_info =
@@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
414425 NN_ERR_PRINTF (" Quantization per channel is not supported" );
415426 return runtime_error;
416427 }
428+
429+ uint32_t model_tensor_size = 1 ;
430+ for (int i = 0 ; i < (int )tensor->dims ->size ; ++i)
431+ model_tensor_size *= (uint32_t )tensor->dims ->data [i];
432+
433+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
434+ if (*output_tensor_size / sizeof (float ) < model_tensor_size) {
435+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
436+ return too_large;
437+ }
438+ #else
439+ /*
440+ * for now, maintain the bug-to-bug compatibility with the old abi,
441+ * where the size here is the number of fp32, not bytes.
442+ */
443+ if (*output_tensor_size < model_tensor_size) {
444+ NN_ERR_PRINTF (" Insufficient memory to copy tensor %d" , index);
445+ return too_large;
446+ }
447+ #endif
448+
417449 uint8_t *ot = tfl_ctx->interpreters [ctx]
418450 .interpreter ->typed_output_tensor <uint8_t >(index);
419451
@@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
426458 for (uint32_t i = 0 ; i < model_tensor_size; ++i) {
427459 output_tensor_f[i] = (ot[i] - zero_point) * scale;
428460 }
461+
462+ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
463+ *output_tensor_size = model_tensor_size * sizeof (float );
464+ #else
465+ /*
466+ * for now, maintain the bug-to-bug compatibility with the old abi,
467+ * where the size here is the number of fp32, not bytes.
468+ */
469+ *output_tensor_size = model_tensor_size;
470+ #endif
429471 }
430472
431- *output_tensor_size = model_tensor_size;
432473 return success;
433474}
434475
0 commit comments