Skip to content

Commit e1ce260

Browse files
pssrawat authored and facebook-github-bot committed
Add portable var_mean.correction_out kernel (pytorch#18775)
Summary: Add a new portable ExecuTorch kernel for var_mean.correction_out that computes both variance and mean in a single function call. Uses the same two-pass algorithm as var.correction_out — the mean computed in pass 1 is written to the output tensor instead of being discarded. Differential Revision: D100016876
1 parent c72d176 commit e1ce260

7 files changed

Lines changed: 849 additions & 0 deletions

File tree

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,8 @@
433433

434434
- op: var.correction_out
435435

436+
- op: var_mean.correction_out
437+
436438
- op: var.out
437439

438440
- op: view_as_real_copy.out
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <c10/util/irange.h>
10+
#include <cmath>
11+
12+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13+
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
14+
#include <executorch/runtime/kernel/kernel_includes.h>
15+
#include <executorch/runtime/platform/assert.h>
16+
17+
namespace torch {
18+
namespace executor {
19+
namespace native {
20+
namespace {
21+
22+
template <typename CTYPE_IN, typename CTYPE_OUT>
23+
void compute_var_mean(
24+
KernelRuntimeContext& ctx,
25+
const Tensor& in,
26+
Tensor& var_out,
27+
Tensor& mean_out,
28+
optional<ArrayRef<int64_t>> dim_list,
29+
const size_t num,
30+
const double denominator) {
31+
CTYPE_OUT* var_data = var_out.mutable_data_ptr<CTYPE_OUT>();
32+
CTYPE_OUT* mean_data = mean_out.mutable_data_ptr<CTYPE_OUT>();
33+
if (num == 0 || denominator <= 0) {
34+
for (const auto out_ix : c10::irange(var_out.numel())) {
35+
var_data[out_ix] = NAN;
36+
mean_data[out_ix] = NAN;
37+
}
38+
} else if (in.numel() > 0) {
39+
MapReduceOverDimListPlan plan(in, dim_list);
40+
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
41+
in, dim_list, var_out, [&](const auto begin, const auto end) {
42+
for (const auto out_ix : c10::irange(begin, end)) {
43+
// Pass 1: compute sum -> mean
44+
CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
45+
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
46+
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
47+
out_ix);
48+
CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
49+
mean_data[out_ix] = mean;
50+
// Pass 2: compute sum of squared deviations
51+
CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
52+
[mean](CTYPE_IN v) {
53+
return (
54+
(static_cast<CTYPE_OUT>(v) - mean) *
55+
(static_cast<CTYPE_OUT>(v) - mean));
56+
},
57+
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
58+
out_ix);
59+
var_data[out_ix] = sum2 / denominator;
60+
}
61+
});
62+
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
63+
}
64+
}
65+
66+
} // namespace
67+
68+
/// Portable kernel for var_mean.correction_out: computes both the variance
/// (into `out0`) and the mean (into `out1`) of `in` over `dim_list` in a
/// single call, using the same two-pass algorithm as var.correction_out.
///
/// @param ctx        Kernel runtime context used for error reporting.
/// @param in         Input tensor (floating-point, incl. half/bfloat16).
/// @param dim_list   Dimensions to reduce over; empty/nullopt reduces all.
/// @param correction Bessel-style correction (defaults to 1 when absent).
/// @param keepdim    Whether reduced dimensions are kept with size 1.
/// @param out0       Output tensor receiving the variance.
/// @param out1       Output tensor receiving the mean.
/// @return (out0, out1) as a tuple of references.
std::tuple<Tensor&, Tensor&> var_mean_correction_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    const optional<Scalar>& correction,
    bool keepdim,
    Tensor& out0,
    Tensor& out1) {
  std::tuple<Tensor&, Tensor&> ret_val(out0, out1);

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args(in, dim_list, keepdim, {}, out0),
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args(in, dim_list, keepdim, {}, out1),
      InvalidArgument,
      ret_val);

  // Bug fix: CTYPE_OUT below is selected from out0's dtype only, but the
  // helper writes BOTH outputs through CTYPE_OUT pointers. A mismatched
  // out1 dtype would write wrong-typed data, so require identical dtypes.
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dtype(out0, out1), InvalidArgument, ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim_list, keepdim, out0) == Error::Ok,
      InvalidArgument,
      ret_val);

  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim_list, keepdim, out1) == Error::Ok,
      InvalidArgument,
      ret_val);

  static constexpr auto name = "var_mean.correction_out";

  // PyTorch semantics: a missing correction means correction = 1
  // (sample variance).
  double correction_val = 1;
  if (correction.has_value()) {
    correction_val = utils::scalar_to<double>(correction.value());
  }

  // `num` is the element count of one reduced slice; `denom` is the
  // degrees of freedom. A non-positive denom yields NaN outputs in the
  // helper.
  const size_t num = get_reduced_dim_product(in, dim_list);
  const double denom = num - correction_val;

  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_FLOATHBF16_TYPES(out0.scalar_type(), ctx, name, CTYPE_OUT, [&] {
      compute_var_mean<CTYPE_IN, CTYPE_OUT>(
          ctx, in, out0, out1, dim_list, num, denom);
    });
  });

  return ret_val;
}
123+
124+
} // namespace native
125+
} // namespace executor
126+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,11 @@
10151015
- arg_meta: null
10161016
kernel_name: torch::executor::var_correction_out
10171017

1018+
- op: var_mean.correction_out
1019+
kernels:
1020+
- arg_meta: null
1021+
kernel_name: torch::executor::var_mean_correction_out
1022+
10181023
- op: var.out
10191024
kernels:
10201025
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ set(all_test_sources
311311
"op_upsample_bilinear2d_aa_test.cpp"
312312
"op_upsample_nearest2d_test.cpp"
313313
"op_var_test.cpp"
314+
"op_var_mean_test.cpp"
314315
"op_view_as_real_copy_test.cpp"
315316
"op_view_copy_test.cpp"
316317
"op_where_test.cpp"

0 commit comments

Comments
 (0)