@@ -347,19 +347,13 @@ Error platform_execute(
347347 int output_count,
348348 Span<executorch::runtime::EValue*> args,
349349 char * /* ethosu_scratch*/ ) {
350- std::vector<size_t > input_copy_sizes;
351- std::vector<const char *> linux_input_ptrs;
352- if (input_count > 0 ) {
353- input_copy_sizes.resize (input_count, 0 );
354- linux_input_ptrs.resize (input_count, nullptr );
355- }
350+ std::vector<size_t > input_copy_sizes (input_count, 0 );
351+ std::vector<const char *> linux_input_ptrs (input_count, nullptr );
356352
357- std::vector<size_t > output_io_bytes;
358- std::vector<char *> linux_output_ptrs;
359- if (output_count > 0 ) {
360- output_io_bytes.resize (output_count, 0 );
361- linux_output_ptrs.resize (output_count, nullptr );
362- }
353+ std::vector<size_t > output_io_bytes (output_count, 0 );
354+ std::vector<char *> linux_output_ptrs (output_count, nullptr );
355+ std::vector<std::vector<char >> output_scratch_buffers (output_count);
356+ std::vector<bool > output_needs_adjustment (output_count, false );
363357
364358 for (int i = 0 ; i < input_count; ++i) {
365359 auto tensor_in = args[i]->toTensor ();
@@ -380,16 +374,12 @@ Error platform_execute(
380374 const size_t tensor_nbytes = tensor_out.nbytes ();
381375 if (i < static_cast <int >(output_io_bytes.size ()) &&
382376 output_io_bytes[i] != tensor_nbytes) {
383- ET_LOG (
384- Error,
385- " Ethos-U Linux backend output size mismatch for index %d: "
386- " driver IO bytes = %zu, tensor bytes = %zu" ,
387- i,
388- output_io_bytes[i],
389- tensor_nbytes);
390- return Error::InvalidState;
377+ output_scratch_buffers[i].resize (output_io_bytes[i]);
378+ linux_output_ptrs[i] = output_scratch_buffers[i].data ();
379+ output_needs_adjustment[i] = true ;
380+ } else {
381+ linux_output_ptrs[i] = tensor_out.mutable_data_ptr <char >();
391382 }
392- linux_output_ptrs[i] = tensor_out.mutable_data_ptr <char >();
393383 }
394384 }
395385
@@ -399,13 +389,37 @@ Error platform_execute(
399389 return Error::InvalidState;
400390 }
401391
402- return invoke_linux_driver (
392+ Error status = invoke_linux_driver (
403393 handles,
404394 linux_input_ptrs,
405395 linux_output_ptrs,
406396 input_copy_sizes,
407397 output_io_bytes,
408398 state->options );
399+ if (status != Error::Ok) {
400+ return status;
401+ }
402+
403+ if (handles.outputs != nullptr ) {
404+ for (int i = 0 ; i < output_count; ++i) {
405+ if (!output_needs_adjustment[i]) {
406+ continue ;
407+ }
408+ auto tensor_out = args[input_count + i]->toTensor ();
409+ const size_t tensor_nbytes = tensor_out.nbytes ();
410+ Error adjust_status = copy_with_layout_adjustment (
411+ handles.outputs ->io [i],
412+ i,
413+ output_scratch_buffers[i].data (),
414+ tensor_out,
415+ tensor_nbytes);
416+ if (adjust_status != Error::Ok) {
417+ return adjust_status;
418+ }
419+ }
420+ }
421+
422+ return Error::Ok;
409423}
410424
411425} // namespace arm
0 commit comments