diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index 4a53519f39a..0e152cbea62 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -433,6 +433,8 @@
 - op: var.correction_out
 
+- op: var_mean.correction_out
+
 - op: var.out
 
 - op: view_as_real_copy.out
 
diff --git a/kernels/portable/cpu/op_var_mean.cpp b/kernels/portable/cpu/op_var_mean.cpp
new file mode 100644
index 00000000000..6e72b268610
--- /dev/null
+++ b/kernels/portable/cpu/op_var_mean.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cmath>
+#include <tuple>
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/reduce_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace {
+
+template <typename CTYPE_IN, typename CTYPE_OUT>
+void compute_var_mean(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    Tensor& var_out,
+    Tensor& mean_out,
+    optional<ArrayRef<int64_t>> dim_list,
+    const size_t num,
+    const double denominator) {
+  CTYPE_OUT* var_data = var_out.mutable_data_ptr<CTYPE_OUT>();
+  CTYPE_OUT* mean_data = mean_out.mutable_data_ptr<CTYPE_OUT>();
+  if (num == 0 || denominator <= 0) {
+    for (const auto out_ix : c10::irange(var_out.numel())) {
+      var_data[out_ix] = NAN;
+      mean_data[out_ix] = NAN;
+    }
+  } else if (in.numel() > 0) {
+    MapReduceOverDimListPlan plan(in, dim_list);
+    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
+        in, dim_list, var_out, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            // Pass 1: compute sum -> mean
+            CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
+                [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
+                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                out_ix);
+            CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
+            mean_data[out_ix] = mean;
+            // Pass 2: compute sum of squared deviations
+            CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
+                [mean](CTYPE_IN v) {
+                  return (
+                      (static_cast<CTYPE_OUT>(v) - mean) *
+                      (static_cast<CTYPE_OUT>(v) - mean));
+                },
+                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                out_ix);
+            var_data[out_ix] = sum2 / denominator;
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
+  }
+}
+
+} // namespace
+
+std::tuple<Tensor&, Tensor&> var_mean_correction_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    optional<ArrayRef<int64_t>> dim_list,
+    const optional<Scalar>& correction,
+    bool keepdim,
+    Tensor& out0,
+    Tensor& out1) {
+  (void)ctx;
+
+  std::tuple<Tensor&, Tensor&> ret_val(out0, out1);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_reduction_args(in, dim_list, keepdim, {}, out0),
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_reduction_args(in, dim_list, keepdim, {}, out1),
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_reduction_out(in, dim_list, keepdim, out0) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_reduction_out(in, dim_list, keepdim, out1) == Error::Ok,
+      InvalidArgument,
+      ret_val);
+
+  static constexpr auto name = "var_mean.correction_out";
+
+  double correction_val = 1;
+  if (correction.has_value()) {
+    correction_val = utils::scalar_to<double>(correction.value());
+  }
+
+  const size_t num = get_reduced_dim_product(in, dim_list);
+  const double denom = num - correction_val;
+
+  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
+    ET_SWITCH_FLOATHBF16_TYPES(out0.scalar_type(), ctx, name, CTYPE_OUT, [&] {
+      compute_var_mean<CTYPE_IN, CTYPE_OUT>(
+          ctx, in, out0, out1, dim_list, num, denom);
+    });
+  });
+
+  return ret_val;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
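
For reference, the two-pass algorithm in compute_var_mean above can be sketched in isolation. The following is a minimal standalone illustration using std::vector in place of ExecuTorch tensors; the function name and layout are illustrative only, not part of the patch. It mirrors the kernel's degenerate-case handling, where an empty reduction or a non-positive denominator (num - correction) produces NaN:

    #include <cmath>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Returns {variance, mean}. The sum of squared deviations is divided by
    // (N - correction), matching the kernel's `denom`.
    std::pair<double, double> var_mean_sketch(
        const std::vector<double>& data, double correction) {
      const double n = static_cast<double>(data.size());
      const double denom = n - correction;
      if (data.empty() || denom <= 0) {
        return {NAN, NAN}; // degenerate case, as in compute_var_mean
      }
      double sum = 0;
      for (double v : data) sum += v; // pass 1: mean
      const double mean = sum / n;
      double sum2 = 0;
      for (double v : data) sum2 += (v - mean) * (v - mean); // pass 2
      return {sum2 / denom, mean};
    }

    int main() {
      auto [var, mean] = var_mean_sketch({4.9, 4.0, 5.6}, /*correction=*/1.0);
      std::printf("var=%f mean=%f\n", var, mean); // var=0.643333 mean=4.833333
    }
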
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index e730a993984..620d97d050f 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -1015,6 +1015,11 @@
     - arg_meta: null
       kernel_name: torch::executor::var_correction_out
 
+- op: var_mean.correction_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::var_mean_correction_out
+
 - op: var.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index d685d291709..2707ba5db71 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -311,6 +311,7 @@ set(all_test_sources
   "op_upsample_bilinear2d_aa_test.cpp"
   "op_upsample_nearest2d_test.cpp"
   "op_var_test.cpp"
+  "op_var_mean_test.cpp"
   "op_view_as_real_copy_test.cpp"
   "op_view_copy_test.cpp"
   "op_where_test.cpp"
diff --git a/kernels/test/op_var_mean_test.cpp b/kernels/test/op_var_mean_test.cpp
new file mode 100644
index 00000000000..7049c21d65b
--- /dev/null
+++ b/kernels/test/op_var_mean_test.cpp
@@ -0,0 +1,705 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+#include <cmath>
+
+using namespace ::testing;
+using executorch::aten::ArrayRef;
+using executorch::aten::Scalar;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using std::optional;
+using torch::executor::testing::TensorFactory;
+
+namespace {
+void expect_tensor_close_with_increased_tol(
+    const Tensor& actual,
+    const Tensor& expected) {
+  if (actual.scalar_type() == ScalarType::BFloat16 ||
+      actual.scalar_type() == ScalarType::Half) {
+    EXPECT_TENSOR_CLOSE_WITH_TOL(expected, actual, 1e-2, 1e-2);
+  } else {
+    EXPECT_TENSOR_CLOSE(expected, actual);
+  }
+}
+} // namespace
+
+class OpVarMeanCorrectionOutTest : public OperatorTest {
+ protected:
+  std::tuple<Tensor&, Tensor&> op_var_mean_correction_out(
+      const Tensor& self,
+      optional<ArrayRef<int64_t>> dim,
+      optional<Scalar>& correction,
+      bool keepdim,
+      Tensor& out0,
+      Tensor& out1) {
+    return torch::executor::aten::var_mean_outf(
+        context_, self, dim, correction, keepdim, out0, out1);
+  }
+
+  template <ScalarType DTYPE>
+  void test_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor x = tf.make({2, 3}, {4.9, 4.0, 5.6, 3.8, 4.9, 5.6});
+    Tensor expected_var = tf.make({2}, {0.72693, 0.93032});
+    Tensor expected_mean = tf.make({2}, {4.833333, 4.766667});
+    optional<Scalar> correction(1.23);
+    Tensor var_out = tf.zeros({2});
+    Tensor mean_out = tf.zeros({2});
+
+    op_var_mean_correction_out(
+        x,
+        ArrayRef<int64_t>{1},
+        correction,
+        /*keepdim=*/false,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(var_out, expected_var);
+    expect_tensor_close_with_increased_tol(mean_out, expected_mean);
+  }
+
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_keepdim() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // clang-format off
+    Tensor self = tf_in.make(
+        {2, 3, 4},
+        {
+          0, 1, 2, 3,
+          4, 5, 6, 7,
+          8, 9, 10, 11,
+
+          12, 13, 14, 15,
+          16, 17, 18, 19,
+          20, 21, 22, 23,
+        });
+    // clang-format on
+
+    // keepdim=true
+    Tensor var_out = tf_out.zeros({2, 3, 1});
+    Tensor mean_out = tf_out.zeros({2, 3, 1});
+    int64_t dims_1[1] = {2};
+    optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims_1, 1}};
+    optional<Scalar> correction(1);
+    op_var_mean_correction_out(
+        self,
+        optional_dim_list,
+        correction,
+        /*keepdim=*/true,
+        var_out,
+        mean_out);
+    // clang-format off
+    expect_tensor_close_with_increased_tol(var_out, tf_out.make(
+        {2, 3, 1},
+        {
+          1.666667,
+          1.666667,
+          1.666667,
+
+          1.666667,
+          1.666667,
+          1.666667,
+        }));
+    expect_tensor_close_with_increased_tol(mean_out, tf_out.make(
+        {2, 3, 1},
+        {
+          1.5,
+          5.5,
+          9.5,
+
+          13.5,
+          17.5,
+          21.5,
+        }));
+    // clang-format on
+
+    // keepdim=false
+    var_out = tf_out.zeros({2, 3});
+    mean_out = tf_out.zeros({2, 3});
+    op_var_mean_correction_out(
+        self,
+        optional_dim_list,
+        correction,
+        /*keepdim=*/false,
+        var_out,
+        mean_out);
+    // clang-format off
+    expect_tensor_close_with_increased_tol(var_out, tf_out.make(
+        {2, 3},
+        {
+          1.666667, 1.666667, 1.666667,
+          1.666667, 1.666667, 1.666667,
+        }));
+    expect_tensor_close_with_increased_tol(mean_out, tf_out.make(
+        {2, 3},
+        {
+          1.5, 5.5, 9.5,
+          13.5, 17.5, 21.5,
+        }));
+    // clang-format on
+  }
+
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_multiple_dims() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // clang-format off
+    Tensor self = tf_in.make(
+        {2, 3, 4},
+        {
+          0, 1, 2, 3,
+          4, 5, 6, 7,
+          8, 9, 10, 11,
+
+          12, 13, 14, 15,
+          16, 17, 18, 19,
+          20, 21, 22, 23,
+        });
+    // clang-format on
+
+    Tensor var_out = tf_out.zeros({1, 1, 4});
+    Tensor mean_out = tf_out.zeros({1, 1, 4});
+    int64_t dims[2] = {0, 1};
+    optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims, 2}};
+    optional<Scalar> correction(1);
+    op_var_mean_correction_out(
+        self,
+        optional_dim_list,
+        correction,
+        /*keepdim=*/true,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(
+        var_out, tf_out.make({1, 1, 4}, {56.0, 56.0, 56.0, 56.0}));
+    expect_tensor_close_with_increased_tol(
+        mean_out, tf_out.make({1, 1, 4}, {10.0, 11.0, 12.0, 13.0}));
+
+    var_out = tf_out.zeros({4});
+    mean_out = tf_out.zeros({4});
+    op_var_mean_correction_out(
+        self,
+        optional_dim_list,
+        correction,
+        /*keepdim=*/false,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(
+        var_out, tf_out.make({4}, {56.0, 56.0, 56.0, 56.0}));
+    expect_tensor_close_with_increased_tol(
+        mean_out, tf_out.make({4}, {10.0, 11.0, 12.0, 13.0}));
+  }
+
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_negative_dim() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // clang-format off
+    Tensor self = tf_in.make(
+        {2, 3, 4},
+        {
+          0, 1, 2, 3,
+          4, 5, 6, 7,
+          8, 9, 10, 11,
+
+          12, 13, 14, 15,
+          16, 17, 18, 19,
+          20, 21, 22, 23,
+        });
+    // clang-format on
+
+    Tensor var_out = tf_out.zeros({2, 1, 4});
+    Tensor mean_out = tf_out.zeros({2, 1, 4});
+    int64_t dims[1] = {-2};
+    optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims, 1}};
+    optional<Scalar> correction(0);
+    op_var_mean_correction_out(
+        self,
+        optional_dim_list,
+        correction,
+        /*keepdim=*/true,
+        var_out,
+        mean_out);
+    // clang-format off
+    expect_tensor_close_with_increased_tol(var_out, tf_out.make(
+        {2, 1, 4},
+        {
+          10.666667, 10.666667, 10.666667, 10.666667,
+
+          10.666667, 10.666667, 10.666667, 10.666667,
+        }));
+    expect_tensor_close_with_increased_tol(mean_out, tf_out.make(
+        {2, 1, 4},
+        {
+          4.0, 5.0, 6.0, 7.0,
+
+          16.0, 17.0, 18.0, 19.0,
+        }));
+    // clang-format on
+  }
+
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_null_and_empty_dim_list() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // clang-format off
+    Tensor self = tf_in.make(
+        {2, 3, 4},
+        {
+          0, 1, 2, 3,
+          4, 5, 6, 7,
+          8, 9, 10, 11,
+
+          12, 13, 14, 15,
+          16, 17, 18, 19,
+          20, 21, 22, 23,
+        });
+    // clang-format on
+
+    // null dim list, correction=1 (unbiased), keepdim=true
+    Tensor var_out = tf_out.zeros({1, 1, 1});
+    Tensor mean_out = tf_out.zeros({1, 1, 1});
+    optional<ArrayRef<int64_t>> null_dim_list;
+    optional<Scalar> correction(1);
+    op_var_mean_correction_out(
+        self, null_dim_list, correction, /*keepdim=*/true, var_out, mean_out);
+    expect_tensor_close_with_increased_tol(
+        var_out, tf_out.make({1, 1, 1}, {50.0}));
+    expect_tensor_close_with_increased_tol(
+        mean_out, tf_out.make({1, 1, 1}, {11.5}));
+
+    // empty dim list, correction=0 (population), keepdim=true
+    optional<ArrayRef<int64_t>> empty_dim_list{ArrayRef<int64_t>{}};
+    optional<Scalar> correction_zero(0);
+    op_var_mean_correction_out(
+        self,
+        empty_dim_list,
+        correction_zero,
+        /*keepdim=*/true,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(
+        var_out, tf_out.make({1, 1, 1}, {47.916668}));
+    expect_tensor_close_with_increased_tol(
+        mean_out, tf_out.make({1, 1, 1}, {11.5}));
+
+    // null dim list, correction=0, keepdim=false
+    var_out = tf_out.zeros({});
+    mean_out = tf_out.zeros({});
+    op_var_mean_correction_out(
+        self,
+        null_dim_list,
+        correction_zero,
+        /*keepdim=*/false,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(
+        var_out, tf_out.make({}, {47.916668}));
+    expect_tensor_close_with_increased_tol(mean_out, tf_out.make({}, {11.5}));
+
+    // empty dim list, correction=1, keepdim=false
+    op_var_mean_correction_out(
+        self,
+        empty_dim_list,
+        correction,
+        /*keepdim=*/false,
+        var_out,
+        mean_out);
+    expect_tensor_close_with_increased_tol(var_out, tf_out.make({}, {50.0}));
+    expect_tensor_close_with_increased_tol(mean_out, tf_out.make({}, {11.5}));
+  }
+
+  template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
+  void test_invalid_dimensions() {
+    TensorFactory<IN_DTYPE> tf_in;
+    TensorFactory<OUT_DTYPE> tf_out;
+
+    // clang-format off
+    Tensor self = tf_in.make(
+        {2, 3, 4},
+        {
+          0, 1, 2, 3,
+          4, 5, 6, 7,
+          8, 9, 10, 11,
+
+          12, 13, 14, 15,
+          16, 17, 18, 19,
+          20, 21, 22, 23,
+        });
+    // clang-format on
+    Tensor var_out = tf_out.zeros({2, 3, 1});
+    Tensor mean_out = tf_out.zeros({2, 3, 1});
+    optional<Scalar> correction(1);
+
+    // out-of-bound dim
+    int64_t dims_1[1] = {3};
+    optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims_1, 1}};
+    ET_EXPECT_KERNEL_FAILURE(
+        context_,
+        op_var_mean_correction_out(
+            self,
+            optional_dim_list,
+            correction,
+            /*keepdim=*/true,
+            var_out,
+            mean_out));
+
+    // duplicate dim
+    int64_t dims_2[2] = {2, 2};
+    optional_dim_list = ArrayRef<int64_t>{dims_2, 2};
+    ET_EXPECT_KERNEL_FAILURE(
+        context_,
+        op_var_mean_correction_out(
+            self,
+            optional_dim_list,
+            correction,
+            /*keepdim=*/true,
+            var_out,
+            mean_out));
+  }
+};
+
+TEST_F(OpVarMeanCorrectionOutTest, SmokeTest) {
+#define TEST_ENTRY(ctype, dtype) test_dtype<ScalarType::dtype>();
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, KeepDim) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen supports fewer dtypes";
+  }
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
+  test_keepdim<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
+
+#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
+  ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+#undef TEST_KERNEL
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, KeepDim_Aten) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen-specific variant of test case";
+  }
+#define TEST_ENTRY(CTYPE, DTYPE) \
+  test_keepdim<ScalarType::DTYPE, ScalarType::DTYPE>();
+
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, MultipleDims) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen supports fewer dtypes";
+  }
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
+  test_multiple_dims<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
+
+#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
+  ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+#undef TEST_KERNEL
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, MultipleDims_Aten) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen-specific variant of test case";
+  }
+#define TEST_ENTRY(CTYPE, DTYPE) \
+  test_multiple_dims<ScalarType::DTYPE, ScalarType::DTYPE>();
+
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, NegativeDim) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen supports fewer dtypes";
+  }
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
+  test_negative_dim<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
+
+#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
+  ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+#undef TEST_KERNEL
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, NegativeDim_Aten) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen-specific variant of test case";
+  }
+#define TEST_ENTRY(CTYPE, DTYPE) \
+  test_negative_dim<ScalarType::DTYPE, ScalarType::DTYPE>();
+
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, NullAndEmptyDimList) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen supports fewer dtypes";
+  }
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
+  test_null_and_empty_dim_list<                                           \
+      ScalarType::INPUT_DTYPE,                                            \
+      ScalarType::OUTPUT_DTYPE>();
+
+#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
+  ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+#undef TEST_KERNEL
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, NullAndEmptyDimList_Aten) {
+  if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen-specific variant of test case";
+  }
+#define TEST_ENTRY(CTYPE, DTYPE) \
+  test_null_and_empty_dim_list<ScalarType::DTYPE, ScalarType::DTYPE>();
+
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, InvalidDimensionListDies) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen kernel test fails";
+  }
+#define TEST_KERNEL(INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE) \
+  test_invalid_dimensions<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
+
+#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
+  ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+
+  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+#undef TEST_ENTRY
+#undef TEST_KERNEL
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, InvalidDTypeDies) {
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    GTEST_SKIP() << "ATen kernel test fails";
+  }
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Int> tf_int;
+
+  // clang-format off
+  Tensor self = tf_int.make(
+      {2, 3, 4},
+      {
+        0, 1, 2, 3,
+        4, 5, 6, 7,
+        8, 9, 10, 11,
+
+        12, 13, 14, 15,
+        16, 17, 18, 19,
+        20, 21, 22, 23,
+      });
+  // clang-format on
+
+  Tensor var_out = tf_float.zeros({2, 3, 1});
+  Tensor mean_out = tf_float.zeros({2, 3, 1});
+  int64_t dims_1[1] = {2};
+  optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims_1, 1}};
+  optional<Scalar> correction(1);
+
+  ET_EXPECT_KERNEL_FAILURE(
+      context_,
+      op_var_mean_correction_out(
+          self,
+          optional_dim_list,
+          correction,
+          /*keepdim=*/true,
+          var_out,
+          mean_out));
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, EmptyInput) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({2, 0, 3}, {});
+  optional<Scalar> correction(1);
+  optional<Scalar> correction_zero(0);
+
+  // empty dim list, correction=1, keepdim=true
+  optional<ArrayRef<int64_t>> dim_list = ArrayRef<int64_t>{};
+  Tensor var_out = tf.zeros({1, 1, 1});
+  Tensor mean_out = tf.zeros({1, 1, 1});
+  op_var_mean_correction_out(
+      x, dim_list, correction, /*keepdim=*/true, var_out, mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, tf.make({1, 1, 1}, {NAN}));
+  EXPECT_TENSOR_CLOSE(mean_out, tf.make({1, 1, 1}, {NAN}));
+
+  // empty dim list, correction=1, keepdim=false
+  var_out = tf.zeros({});
+  mean_out = tf.zeros({});
+  op_var_mean_correction_out(
+      x, dim_list, correction, /*keepdim=*/false, var_out, mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, tf.make({}, {NAN}));
+  EXPECT_TENSOR_CLOSE(mean_out, tf.make({}, {NAN}));
+
+  // reduce along the empty dim
+  int64_t dims1[1] = {1};
+  dim_list = ArrayRef<int64_t>{dims1, 1};
+  var_out = tf.zeros({2, 3});
+  mean_out = tf.zeros({2, 3});
+  op_var_mean_correction_out(
+      x, dim_list, correction, /*keepdim=*/false, var_out, mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, tf.make({2, 3}, {NAN, NAN, NAN, NAN, NAN, NAN}));
+  EXPECT_TENSOR_CLOSE(
+      mean_out, tf.make({2, 3}, {NAN, NAN, NAN, NAN, NAN, NAN}));
+
+  // reduce along a non-empty dim of an empty tensor
+  int64_t dims2[1] = {2};
+  dim_list = ArrayRef<int64_t>{dims2, 1};
+  var_out = tf.make({2, 0, 1}, {});
+  mean_out = tf.make({2, 0, 1}, {});
+  op_var_mean_correction_out(
+      x, dim_list, correction, /*keepdim=*/true, var_out, mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, tf.make({2, 0, 1}, {}));
+  EXPECT_TENSOR_CLOSE(mean_out, tf.make({2, 0, 1}, {}));
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, DynamicShapeUpperBoundSameAsExpected) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({3, 2}, {0.49, 0.40, 0.56, 0.38, 0.49, 0.56});
+  Tensor expected_var = tf.make({3}, {0.004050, 0.016200, 0.002450});
+  Tensor expected_mean = tf.make({3}, {0.445, 0.47, 0.525});
+  optional<Scalar> correction(1);
+
+  Tensor var_out =
+      tf.zeros({3}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  Tensor mean_out =
+      tf.zeros({3}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  op_var_mean_correction_out(
+      x,
+      ArrayRef<int64_t>{1},
+      correction,
+      /*keepdim=*/false,
+      var_out,
+      mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, expected_var);
+  EXPECT_TENSOR_CLOSE(mean_out, expected_mean);
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, DynamicShapeUpperBoundLargerThanExpected) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({3, 2}, {0.49, 0.40, 0.56, 0.38, 0.49, 0.56});
+  Tensor expected_var = tf.make({3}, {0.004050, 0.016200, 0.002450});
+  Tensor expected_mean = tf.make({3}, {0.445, 0.47, 0.525});
+  optional<Scalar> correction(1);
+
+  Tensor var_out =
+      tf.zeros({10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  Tensor mean_out =
+      tf.zeros({10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
+  op_var_mean_correction_out(
+      x,
+      ArrayRef<int64_t>{1},
+      correction,
+      /*keepdim=*/false,
+      var_out,
+      mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, expected_var);
+  EXPECT_TENSOR_CLOSE(mean_out, expected_mean);
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, DynamicShapeUnbound) {
+  GTEST_SKIP() << "Dynamic shape unbound not supported";
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({3, 2}, {0.49, 0.40, 0.56, 0.38, 0.49, 0.56});
+  Tensor expected_var = tf.make({3}, {0.004050, 0.016200, 0.002450});
+  Tensor expected_mean = tf.make({3}, {0.445, 0.47, 0.525});
+  optional<Scalar> correction(1);
+
+  Tensor var_out =
+      tf.zeros({1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
+  Tensor mean_out =
+      tf.zeros({1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
+  op_var_mean_correction_out(
+      x,
+      ArrayRef<int64_t>{1},
+      correction,
+      /*keepdim=*/false,
+      var_out,
+      mean_out);
+  EXPECT_TENSOR_CLOSE(var_out, expected_var);
+  EXPECT_TENSOR_CLOSE(mean_out, expected_mean);
+}
+
+TEST_F(OpVarMeanCorrectionOutTest, InfinityAndNANTest) {
+  TensorFactory<ScalarType::Float> tf;
+  // clang-format off
+  Tensor self = tf.make(
+      {2, 3, 4},
+      {
+        0, 1, 2, INFINITY,
+        INFINITY, -INFINITY, 1, 0,
+        NAN, INFINITY, -INFINITY, 2,
+
+        NAN, NAN, 1, 0,
+        0, INFINITY, NAN, 4,
+        1, NAN, 3.14, 2,
+      });
+  // clang-format on
+
+  Tensor var_out = tf.zeros({2, 3, 1});
+  Tensor mean_out = tf.zeros({2, 3, 1});
+  int64_t dims[1] = {-1};
+  optional<ArrayRef<int64_t>> optional_dim_list{ArrayRef<int64_t>{dims, 1}};
+  optional<Scalar> correction(1);
+  op_var_mean_correction_out(
+      self,
+      optional_dim_list,
+      correction,
+      /*keepdim=*/true,
+      var_out,
+      mean_out);
+  // All rows contain INFINITY or NAN, so var should be NAN for all rows.
+  // Mean can be INFINITY or NAN depending on input values, so only check var.
+  // clang-format off
+  EXPECT_TENSOR_CLOSE(var_out, tf.make(
+      {2, 3, 1},
+      {
+        NAN,
+        NAN,
+        NAN,
+
+        NAN,
+        NAN,
+        NAN,
+      }));
+  // clang-format on
+}
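
As a hand-check of the SmokeTest expectations above (a worked example following the same two-pass math; not part of the patch): with the fractional correction 1.23, the divisor for a row of three elements is N - correction = 3 - 1.23 = 1.77:

    #include <cstdio>

    int main() {
      const double row[3] = {4.9, 4.0, 5.6}; // first row of the test input
      double sum = 0;
      for (double v : row) sum += v;
      const double mean = sum / 3; // 4.833333, the expected mean
      double sum2 = 0;
      for (double v : row) sum2 += (v - mean) * (v - mean); // 1.286667
      const double var = sum2 / (3 - 1.23); // 1.286667 / 1.77 = 0.726931
      std::printf("var=%f mean=%f\n", var, mean); // matches {0.72693, 4.833333}
    }

The second row {3.8, 4.9, 5.6} works out the same way to the expected 0.93032.
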
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 60cc61c23db..abf6329248d 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -330,6 +330,7 @@ def define_common_targets():
     _common_op_test("op_upsample_bilinear2d_aa_test", ["portable"])
     _common_op_test("op_upsample_nearest2d_test", ["aten", "portable"])
     _common_op_test("op_var_test", ["aten", "portable"])
+    _common_op_test("op_var_mean_test", ["aten", "portable"])
     _common_op_test("op_view_as_real_copy_test", ["aten", "portable"])
     _common_op_test("op_view_copy_test", ["aten", "portable"])
     _common_op_test("op_where_test", ["aten", "portable"])
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index de5675b9098..c0af84a2477 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1332,6 +1332,15 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:reduce_util",
         ],
     ),
+    op_target(
+        name = "op_var_mean",
+        deps = [
+            ":scalar_utils",
+            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+            "//executorch/kernels/portable/cpu/util:reduce_util",
+        ],
+    ),
     op_target(
         name = "op_view_as_real_copy",
         deps = [