Skip to content

Commit 59bd4af

Browse files
committed
wasi_nn_tensorflowlite.cpp: fix get_output return size
It should be the byte size, not the number of (fp32) values. I'm ambivalent about how to deal with compatibility for the legacy wamr-specific "wasi_nn" ABI. For now, I avoided changing it (so that the existing tests using the legacy ABI, namely test_tensorflow.c and test_tensorflow_quantized.c, pass as they are). If we have any users who still want to use the legacy ABI, I suppose they consider compatibility more important than consistency with the other backends. cf. #4376
1 parent a29f394 commit 59bd4af

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -389,23 +389,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
389389
return too_large;
390390
}
391391

392-
uint32_t model_tensor_size = 1;
393-
for (int i = 0; i < (int)tensor->dims->size; ++i)
394-
model_tensor_size *= (uint32_t)tensor->dims->data[i];
395-
396-
if (*output_tensor_size < model_tensor_size) {
397-
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
398-
return too_large;
399-
}
400-
401392
if (tensor->quantization.type == kTfLiteNoQuantization) {
402393
NN_DBG_PRINTF("No quantization information");
403-
float *ot =
404-
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
405-
index);
406-
407-
int size = model_tensor_size * sizeof(float);
408-
bh_memcpy_s(output_tensor, size, ot, size);
394+
if (*output_tensor_size < tensor->bytes) {
395+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
396+
return too_large;
397+
}
398+
bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
399+
tensor->bytes);
400+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
401+
*output_tensor_size = tensor->bytes;
402+
#else
403+
*output_tensor_size = tensor->bytes / sizeof(float);
404+
#endif
409405
}
410406
else { // TODO: Assuming uint8 quantized networks.
411407
TfLiteAffineQuantization *quant_info =
@@ -414,6 +410,16 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
414410
NN_ERR_PRINTF("Quantization per channel is not supported");
415411
return runtime_error;
416412
}
413+
414+
uint32_t model_tensor_size = 1;
415+
for (int i = 0; i < (int)tensor->dims->size; ++i)
416+
model_tensor_size *= (uint32_t)tensor->dims->data[i];
417+
418+
if (*output_tensor_size / sizeof(float) < model_tensor_size) {
419+
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
420+
return too_large;
421+
}
422+
417423
uint8_t *ot = tfl_ctx->interpreters[ctx]
418424
.interpreter->typed_output_tensor<uint8_t>(index);
419425

@@ -426,9 +432,14 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
426432
for (uint32_t i = 0; i < model_tensor_size; ++i) {
427433
output_tensor_f[i] = (ot[i] - zero_point) * scale;
428434
}
435+
436+
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
437+
*output_tensor_size = model_tensor_size * sizeof(float);
438+
#else
439+
*output_tensor_size = model_tensor_size;
440+
#endif
429441
}
430442

431-
*output_tensor_size = model_tensor_size;
432443
return success;
433444
}
434445

0 commit comments

Comments
 (0)