Commit 5deccf8

pssrawat authored and facebook-github-bot committed
Add portable var_mean.correction_out kernel (#18775)
Summary: Add a new portable ExecuTorch kernel for var_mean.correction_out that computes both variance and mean in a single function call. It uses the same two-pass algorithm as var.correction_out; the mean computed in pass 1 is written to the mean output tensor instead of being discarded.

Reviewed By: manuelcandales

Differential Revision: D100016876
1 parent 2d13fae commit 5deccf8
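As a quick illustration of that two-pass scheme, the following plain-C++ sketch computes variance and mean of a flat buffer with the same correction handling the kernel uses (denominator = N - correction, NaN outputs when the denominator is not positive). It is a simplified illustration, not the ExecuTorch kernel itself; the function name and types are illustrative.

// Minimal sketch of the two-pass var/mean scheme over a flat array.
// Illustrative only; the real kernel reduces over an arbitrary dim_list
// and dispatches over ExecuTorch dtypes.
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

std::pair<double, double> var_mean_two_pass(
    const std::vector<float>& x, double correction = 1.0) {
  const size_t n = x.size();
  const double denom = static_cast<double>(n) - correction;
  if (n == 0 || denom <= 0) {
    // Mirrors the kernel's degenerate case: both outputs become NaN.
    return {NAN, NAN};
  }
  // Pass 1: sum -> mean.
  double sum = 0.0;
  for (float v : x) {
    sum += v;
  }
  const double mean = sum / static_cast<double>(n);
  // Pass 2: sum of squared deviations -> variance.
  double sum2 = 0.0;
  for (float v : x) {
    const double d = static_cast<double>(v) - mean;
    sum2 += d * d;
  }
  return {sum2 / denom, mean};
}

With the default correction of 1, an input of {1, 2, 3, 4} gives mean 2.5 and variance 5/3 ≈ 1.667, matching torch.var_mean's unbiased default.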

7 files changed

Lines changed: 848 additions & 0 deletions

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -433,6 +433,8 @@
 
 - op: var.correction_out
 
+- op: var_mean.correction_out
+
 - op: var.out
 
 - op: view_as_real_copy.out
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <c10/util/irange.h>
#include <cmath>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

template <typename CTYPE_IN, typename CTYPE_OUT>
void compute_var_mean(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    Tensor& var_out,
    Tensor& mean_out,
    optional<ArrayRef<int64_t>> dim_list,
    const size_t num,
    const double denominator) {
  CTYPE_OUT* var_data = var_out.mutable_data_ptr<CTYPE_OUT>();
  CTYPE_OUT* mean_data = mean_out.mutable_data_ptr<CTYPE_OUT>();
  if (num == 0 || denominator <= 0) {
    for (const auto out_ix : c10::irange(var_out.numel())) {
      var_data[out_ix] = NAN;
      mean_data[out_ix] = NAN;
    }
  } else if (in.numel() > 0) {
    MapReduceOverDimListPlan plan(in, dim_list);
    const bool success = parallel_for_each_reduce_over_dim_list_output_index(
        in, dim_list, var_out, [&](const auto begin, const auto end) {
          for (const auto out_ix : c10::irange(begin, end)) {
            // Pass 1: compute sum -> mean
            CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
                [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
                out_ix);
            CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
            mean_data[out_ix] = mean;
            // Pass 2: compute sum of squared deviations
            CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
                [mean](CTYPE_IN v) {
                  return (
                      (static_cast<CTYPE_OUT>(v) - mean) *
                      (static_cast<CTYPE_OUT>(v) - mean));
                },
                [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
                out_ix);
            var_data[out_ix] = sum2 / denominator;
          }
        });
    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
  }
}

} // namespace

std::tuple<Tensor&, Tensor&> var_mean_correction_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    const optional<Scalar>& correction,
    bool keepdim,
    Tensor& out0,
    Tensor& out1) {
  (void)ctx;

  std::tuple<Tensor&, Tensor&> ret_val(out0, out1);

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args(in, dim_list, keepdim, {}, out0),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args(in, dim_list, keepdim, {}, out1),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim_list, keepdim, out0) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim_list, keepdim, out1) == Error::Ok,
      InvalidArgument,
      ret_val);

  static constexpr auto name = "var_mean.correction_out";

  double correction_val = 1;
  if (correction.has_value()) {
    correction_val = utils::scalar_to<double>(correction.value());
  }

  const size_t num = get_reduced_dim_product(in, dim_list);
  const double denom = num - correction_val;

  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_FLOATHBF16_TYPES(out0.scalar_type(), ctx, name, CTYPE_OUT, [&] {
      compute_var_mean<CTYPE_IN, CTYPE_OUT>(
          ctx, in, out0, out1, dim_list, num, denom);
    });
  });

  return ret_val;
}

} // namespace native
} // namespace executor
} // namespace torch
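As a rough mental model of the per-output-index loop above (one (variance, mean) pair for each position that survives the reduction), here is a small standalone sketch that reduces over the last dimension of a row-major 2-D buffer. It deliberately sidesteps the MapReduceOverDimListPlan and parallel_for utilities used by the real kernel; the function name, shapes, and layout assumptions are illustrative only.

// Illustrative sketch: var/mean reduced over the last dim of a
// row-major [rows, cols] buffer, producing one output element per row.
// Mirrors the structure of the kernel's loop over out_ix, not its
// actual reduction machinery.
#include <cmath>
#include <cstddef>
#include <vector>

void var_mean_over_last_dim(
    const std::vector<float>& in, // size == rows * cols
    size_t rows,
    size_t cols,
    double correction,
    std::vector<double>& var_out, // resized to rows
    std::vector<double>& mean_out) {
  var_out.assign(rows, 0.0);
  mean_out.assign(rows, 0.0);
  const double denom = static_cast<double>(cols) - correction;
  for (size_t r = 0; r < rows; ++r) {
    if (cols == 0 || denom <= 0) {
      // Degenerate reduction: both outputs become NaN, as in the kernel.
      var_out[r] = NAN;
      mean_out[r] = NAN;
      continue;
    }
    const float* row = in.data() + r * cols;
    // Pass 1: mean of this row.
    double sum = 0.0;
    for (size_t c = 0; c < cols; ++c) {
      sum += row[c];
    }
    const double mean = sum / static_cast<double>(cols);
    mean_out[r] = mean;
    // Pass 2: sum of squared deviations for this row.
    double sum2 = 0.0;
    for (size_t c = 0; c < cols; ++c) {
      const double d = row[c] - mean;
      sum2 += d * d;
    }
    var_out[r] = sum2 / denom;
  }
}

Each row plays the role of one reduced slice: pass 1 yields that row's mean, pass 2 its sum of squared deviations, and the same denominator rule (count minus correction, NaN when non-positive) applies per output element, just as in the kernel.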

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -1015,6 +1015,11 @@
     - arg_meta: null
       kernel_name: torch::executor::var_correction_out
 
+- op: var_mean.correction_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::var_mean_correction_out
+
 - op: var.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -311,6 +311,7 @@ set(all_test_sources
   "op_upsample_bilinear2d_aa_test.cpp"
   "op_upsample_nearest2d_test.cpp"
   "op_var_test.cpp"
+  "op_var_mean_test.cpp"
   "op_view_as_real_copy_test.cpp"
   "op_view_copy_test.cpp"
   "op_where_test.cpp"
