diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp
index 0bde69fb1d7..103a8812931 100644
--- a/backends/xnnpack/runtime/XNNExecutor.cpp
+++ b/backends/xnnpack/runtime/XNNExecutor.cpp
@@ -127,7 +127,7 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) {
           xnn_status_to_string(status));
     }
   }
-  // // Propagate Input Shape and Memory Plan for increased allocation
+  // Propagate Input Shape and Memory Plan for increased allocation
   status = xnn_reshape_runtime(runtime_.get());
 
   ET_CHECK_OR_RETURN_ERROR(
@@ -136,6 +136,12 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) {
       "Internal Error: Propagating input shapes failed with code: %s",
       xnn_status_to_string(status));
 
+  // Resize output tensors.
+  Error err = resize_outputs(args);
+  if (err != Error::Ok) {
+    return err;
+  }
+
   return Error::Ok;
 }
 
@@ -188,14 +194,7 @@ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) {
 }
 
 /**
- * Prepares the outputs for ExecuTorch
- *
- * Resizes the output tensors based on the output shapes returned by
- * the xnnpack runtime.
- *
- * Note: For arg_max pooling, we recast the output index tensor. Since
- * XNNPACK gives the index tensor to us as int32, we need to convert it
- * back to int64 for ExecuTorch.
+ * Resizes output tensors to match XNNPACK's computed shapes.
  */
 ET_NODISCARD Error XNNExecutor::resize_outputs(Span args) const {
   size_t output_idx_start = input_ids_.size();
@@ -239,6 +238,22 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(Span args) const {
       ET_LOG(Error, "Failed to resize output tensor for XNNExecutor");
       return err;
     }
+  }
+
+  return Error::Ok;
+}
+
+/**
+ * Converts output data types after XNNPACK execution.
+ *
+ * For arg_max pooling, XNNPACK outputs int32 index tensors that need
+ * to be converted to int64 for ExecuTorch.
+ */
+ET_NODISCARD Error XNNExecutor::convert_outputs(Span args) const {
+  size_t output_idx_start = input_ids_.size();
+  for (size_t i = output_idx_start; i < externals_.size(); ++i) {
+    uint32_t ext_id = externals_[i].id;
+    Tensor* out_tensor = &args[ext_id]->toTensor();
 
     // Output datatype is int64. However, XNNPACK doesn't support
     // int64. This means that the data was put into this tensor
diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h
index 6c07771b02a..fa7c8360be4 100644
--- a/backends/xnnpack/runtime/XNNExecutor.h
+++ b/backends/xnnpack/runtime/XNNExecutor.h
@@ -88,13 +88,21 @@ class XNNExecutor {
       executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext& context);
 
   /**
-   * Prepares the outputs to be returned by the delegate
+   * Resizes output tensors to match XNNPACK's computed shapes.
    *
-   * Performs any post processing of outputs like tensor resizing
    */
   ET_NODISCARD executorch::runtime::Error resize_outputs(
       executorch::runtime::Span args) const;
 
+  /**
+   * Converts output data types after XNNPACK execution.
+   *
+   * For arg_max pooling, XNNPACK outputs int32 index tensors that need
+   * to be converted to int64 for ExecuTorch.
+   */
+  ET_NODISCARD executorch::runtime::Error convert_outputs(
+      executorch::runtime::Span args) const;
+
   friend class XNNCompiler;
 };
 
diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp
index b0e7cd66f49..23a3f4c4b1f 100644
--- a/backends/xnnpack/runtime/XNNPACKBackend.cpp
+++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp
@@ -152,8 +152,8 @@ class XnnpackBackend final
       return err;
     }
 
-    // Resize outputs and recast pointers if necessary
-    err = executor->resize_outputs(args);
+    // Convert output data types if necessary (e.g., int32 -> int64 for Long)
+    err = executor->convert_outputs(args);
 
     return err;
   }
diff --git a/backends/xnnpack/test/runtime/test_xnnexecutor.cpp b/backends/xnnpack/test/runtime/test_xnnexecutor.cpp
index 1963fd15d05..63b9c096b2b 100644
--- a/backends/xnnpack/test/runtime/test_xnnexecutor.cpp
+++ b/backends/xnnpack/test/runtime/test_xnnexecutor.cpp
@@ -174,7 +174,7 @@ TEST(XNNExecutorTest, ResizeOutputsWithLongTensorConvertsInt32ToInt64) {
   ASSERT_EQ(executor.prepare_args(span), Error::Ok);
   executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext context;
   ASSERT_EQ(executor.forward(context), Error::Ok);
-  ASSERT_EQ(executor.resize_outputs(span), Error::Ok);
+  ASSERT_EQ(executor.convert_outputs(span), Error::Ok);
   Tensor& result = args[2]->toTensor();
   ASSERT_EQ(result.scalar_type(), executorch::aten::ScalarType::Long);