
Commit b5a2967

kernels/custom: add NEON sum.IntList_out for arm64
Drops native_call_sum.IntList_out cumulative time by ~65% on Pixel 9 for the polycam depth model (5 calls, aggregate 3.0 ms -> 1.1 ms).

The portable kernel's fast path is a scalar inner loop that the compiler may auto-vectorize but doesn't guarantee. This kernel uses explicit NEON intrinsics for the two shapes the model actually hits — innermost-dim reduction and middle-dim (strided) reduction — and falls back to the portable-style parallel_for + MapReduce plan for everything else (multi-dim reductions, dtype conversion, complex types).

Implementation details (kernels/custom/neon_sum.cpp):

* sum_innermost: one reduction per outer row, NEON float32x4 / float16x8 loads with horizontal reduction at the end. fp16 inputs accumulate in fp32 for precision (long reductions in pure fp16 lose meaningful trailing digits).
* sum_strided: non-innermost reduction, vectorized across the contiguous inner axis rather than the reduction axis. 4 fp32 or 8 fp16 inner positions per iteration; same fp32 accumulator rule for fp16.
* Scalar fallback mirrors both shapes with f64 accumulators for non-arm64 hosts and for dtypes we don't special-case.

Integration follows the grid_sampler_2d pattern from the prior commit: new source file added to custom_kernels, new optimized.yaml entry pointing at custom::sum_dim_out.
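The fp32-accumulator rule above can be seen with a minimal standalone snippet (illustrative only, not part of this commit; it assumes an arm64 toolchain where __fp16 is available, as in neon_sum.cpp). Summing 4096 ones in a pure fp16 accumulator stalls at 2048, because the spacing between representable fp16 values in [2048, 4096) is 2, so adding 1.0 rounds back to the same value; accumulating in fp32 and rounding once at the end stays exact:

#include <cstdio>

int main() {
  __fp16 acc_h = 0.0f; // pure fp16 running sum
  float acc_f = 0.0f;  // fp32 running sum, rounded to fp16 once at the end
  for (int i = 0; i < 4096; ++i) {
    acc_h = static_cast<__fp16>(acc_h + static_cast<__fp16>(1.0f));
    acc_f += 1.0f;
  }
  std::printf("fp16 accumulator: %g\n", static_cast<float>(acc_h));                       // prints 2048
  std::printf("fp32 accumulator: %g\n", static_cast<float>(static_cast<__fp16>(acc_f)));  // prints 4096
  return 0;
}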
1 parent 0f79d78 commit b5a2967

3 files changed

Lines changed: 352 additions & 0 deletions


kernels/custom/neon_sum.cpp

Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
// NEON-optimized sum.IntList_out for ExecuTorch.
//
// The portable kernel hands off to a generic parallel_for + MapReduce plan,
// which the compiler may auto-vectorize but doesn't guarantee. This kernel
// picks the common single-dim cases and does them with explicit NEON
// intrinsics, falling back to a scalar loop for anything else (multi-dim
// reductions, dtype mismatch, complex types).
//
// Signature matches kernels/portable/cpu/op_sum.cpp::sum_dim_out.

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h>

#ifdef __aarch64__
#include <arm_neon.h>
#endif

#include <cmath>
#include <cstring>
#include <optional>

namespace custom {
namespace native {

using exec_aten::ArrayRef;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::optional;
using torch::executor::KernelRuntimeContext;

namespace {

// Normalize a possibly-negative dim into [0, ndim).
inline int64_t normalize_dim(int64_t dim, int64_t ndim) {
  return dim < 0 ? dim + ndim : dim;
}

// Compute (outer_size, reduce_size, inner_size) around a single reduction dim.
inline void compute_partition_sizes(
    const Tensor& in,
    int64_t dim,
    int64_t& outer_size,
    int64_t& reduce_size,
    int64_t& inner_size) {
  outer_size = 1;
  reduce_size = in.size(dim);
  inner_size = 1;
  for (int64_t d = 0; d < dim; ++d) {
    outer_size *= in.size(d);
  }
  for (int64_t d = dim + 1; d < in.dim(); ++d) {
    inner_size *= in.size(d);
  }
}

#ifdef __aarch64__

// Sum along the innermost (contiguous) dim: one reduction per outer row.
// in  : shape [outer_size, reduce_size]
// out : shape [outer_size]
template <typename T>
inline void sum_innermost(
    const T* in, T* out, int64_t outer_size, int64_t reduce_size);

template <>
inline void sum_innermost<float>(
    const float* in, float* out, int64_t outer_size, int64_t reduce_size) {
  for (int64_t i = 0; i < outer_size; ++i) {
    const float* row = in + i * reduce_size;
    float32x4_t acc = vdupq_n_f32(0.0f);
    int64_t j = 0;
    for (; j + 15 < reduce_size; j += 16) {
      acc = vaddq_f32(acc, vld1q_f32(row + j));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 4));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 8));
      acc = vaddq_f32(acc, vld1q_f32(row + j + 12));
    }
    for (; j + 3 < reduce_size; j += 4) {
      acc = vaddq_f32(acc, vld1q_f32(row + j));
    }
    float sum = vaddvq_f32(acc);
    for (; j < reduce_size; ++j) {
      sum += row[j];
    }
    out[i] = sum;
  }
}

template <>
inline void sum_innermost<__fp16>(
    const __fp16* in, __fp16* out, int64_t outer_size, int64_t reduce_size) {
  // Accumulate in fp32 for precision; long reductions in pure fp16 lose
  // significant trailing digits.
  for (int64_t i = 0; i < outer_size; ++i) {
    const __fp16* row = in + i * reduce_size;
    float32x4_t acc = vdupq_n_f32(0.0f);
    int64_t j = 0;
    for (; j + 7 < reduce_size; j += 8) {
      float16x8_t v = vld1q_f16(row + j);
      acc = vaddq_f32(acc, vcvt_f32_f16(vget_low_f16(v)));
      acc = vaddq_f32(acc, vcvt_f32_f16(vget_high_f16(v)));
    }
    float sum = vaddvq_f32(acc);
    for (; j < reduce_size; ++j) {
      sum += static_cast<float>(row[j]);
    }
    out[i] = static_cast<__fp16>(sum);
  }
}

// Sum along a non-innermost dim: reduce over `reduce_size` elements spaced
// `inner_size` apart, for each (outer, inner) pair.
// in  : shape [outer_size, reduce_size, inner_size]
// out : shape [outer_size, inner_size]
// Vectorized over the contiguous `inner` axis.
template <typename T>
inline void sum_strided(
    const T* in,
    T* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size);

template <>
inline void sum_strided<float>(
    const float* in,
    float* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const float* in_o = in + o * outer_stride;
    float* out_o = out + o * inner_size;
    int64_t j = 0;
    for (; j + 3 < inner_size; j += 4) {
      float32x4_t acc = vdupq_n_f32(0.0f);
      for (int64_t k = 0; k < reduce_size; ++k) {
        acc = vaddq_f32(acc, vld1q_f32(in_o + k * reduce_stride + j));
      }
      vst1q_f32(out_o + j, acc);
    }
    for (; j < inner_size; ++j) {
      float sum = 0.0f;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += in_o[k * reduce_stride + j];
      }
      out_o[j] = sum;
    }
  }
}

template <>
inline void sum_strided<__fp16>(
    const __fp16* in,
    __fp16* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const __fp16* in_o = in + o * outer_stride;
    __fp16* out_o = out + o * inner_size;
    int64_t j = 0;
    // Accumulate in fp32 for precision, 8 contiguous inner positions per iter.
    for (; j + 7 < inner_size; j += 8) {
      float32x4_t acc_lo = vdupq_n_f32(0.0f);
      float32x4_t acc_hi = vdupq_n_f32(0.0f);
      for (int64_t k = 0; k < reduce_size; ++k) {
        float16x8_t v = vld1q_f16(in_o + k * reduce_stride + j);
        acc_lo = vaddq_f32(acc_lo, vcvt_f32_f16(vget_low_f16(v)));
        acc_hi = vaddq_f32(acc_hi, vcvt_f32_f16(vget_high_f16(v)));
      }
      vst1q_f16(
          out_o + j,
          vcombine_f16(vcvt_f16_f32(acc_lo), vcvt_f16_f32(acc_hi)));
    }
    for (; j < inner_size; ++j) {
      float sum = 0.0f;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += static_cast<float>(in_o[k * reduce_stride + j]);
      }
      out_o[j] = static_cast<__fp16>(sum);
    }
  }
}

#endif // __aarch64__

// Scalar fallback: equivalent semantics to the NEON paths, used on non-arm64
// hosts and for dtypes we don't special-case.
template <typename T_IN, typename T_OUT>
inline void sum_scalar(
    const T_IN* in,
    T_OUT* out,
    int64_t outer_size,
    int64_t reduce_size,
    int64_t inner_size) {
  const int64_t reduce_stride = inner_size;
  const int64_t outer_stride = reduce_size * inner_size;
  for (int64_t o = 0; o < outer_size; ++o) {
    const T_IN* in_o = in + o * outer_stride;
    T_OUT* out_o = out + o * inner_size;
    for (int64_t j = 0; j < inner_size; ++j) {
      double sum = 0.0;
      for (int64_t k = 0; k < reduce_size; ++k) {
        sum += static_cast<double>(in_o[k * reduce_stride + j]);
      }
      out_o[j] = static_cast<T_OUT>(sum);
    }
  }
}

} // namespace

Tensor& sum_dim_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      torch::executor::check_reduction_args(in, dim_list, keepdim, dtype, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::resize_reduction_out(in, dim_list, keepdim, out) ==
          torch::executor::Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::tensors_have_same_dim_order(in, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      torch::executor::tensor_is_default_dim_order(in),
      InvalidArgument,
      out);

  if (in.numel() == 0) {
    if (out.numel() > 0) {
      std::memset(out.mutable_data_ptr(), 0, out.nbytes());
    }
    return out;
  }

  // We only fast-path the common case: single reduction dim, matching dtype,
  // non-complex type, contiguous input. Everything else falls through to the
  // scalar path which has identical semantics but no vectorization.
  const bool fast_eligible = dim_list.has_value() &&
      dim_list.value().size() == 1 &&
      in.scalar_type() == out.scalar_type() &&
      !executorch::runtime::isComplexType(in.scalar_type()) &&
      torch::executor::tensor_is_contiguous(in);

  if (fast_eligible) {
    const int64_t dim = normalize_dim(dim_list.value()[0], in.dim());
    int64_t outer_size, reduce_size, inner_size;
    compute_partition_sizes(in, dim, outer_size, reduce_size, inner_size);

#ifdef __aarch64__
    if (in.scalar_type() == ScalarType::Float) {
      const float* ip = in.const_data_ptr<float>();
      float* op = out.mutable_data_ptr<float>();
      if (inner_size == 1) {
        sum_innermost<float>(ip, op, outer_size, reduce_size);
      } else {
        sum_strided<float>(ip, op, outer_size, reduce_size, inner_size);
      }
      return out;
    } else if (in.scalar_type() == ScalarType::Half) {
      static_assert(sizeof(__fp16) == 2, "expected __fp16 == 2 bytes");
      const __fp16* ip =
          reinterpret_cast<const __fp16*>(in.const_data_ptr<uint16_t>());
      __fp16* op =
          reinterpret_cast<__fp16*>(out.mutable_data_ptr<uint16_t>());
      if (inner_size == 1) {
        sum_innermost<__fp16>(ip, op, outer_size, reduce_size);
      } else {
        sum_strided<__fp16>(ip, op, outer_size, reduce_size, inner_size);
      }
      return out;
    }
#endif

    // Scalar fast path for dtypes we don't have NEON for, or non-arm64 hosts.
    // @lint-ignore CLANGTIDY facebook-hte-CArray
    static constexpr const char op_name[] = "sum.IntList_out";
    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
      const CTYPE* ip = in.const_data_ptr<CTYPE>();
      CTYPE* op = out.mutable_data_ptr<CTYPE>();
      sum_scalar<CTYPE, CTYPE>(ip, op, outer_size, reduce_size, inner_size);
    });
    return out;
  }

  // Generic fallback: multi-dim reduction, dtype conversion, or any other
  // case the fast path rejects. Semantically matches the portable kernel.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "sum.IntList_out";
  std::optional<torch::executor::MapReduceOverDimListPlan> plan;
  plan.emplace(in, dim_list);
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
    ET_SWITCH_REALHBBF16_TYPES(
        out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
          CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
          const bool success =
              torch::executor::
                  parallel_for_each_reduce_over_dim_list_output_index(
                      in,
                      dim_list,
                      out,
                      [&](const auto begin, const auto end) {
                        for (const auto out_ix : c10::irange(begin, end)) {
                          CTYPE_OUT sum = 0;
                          sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
                              [](CTYPE_IN v) {
                                return static_cast<CTYPE_OUT>(v);
                              },
                              [](CTYPE_OUT outv, CTYPE_OUT acc) {
                                return acc + outv;
                              },
                              out_ix);
                          out_data[out_ix] = sum;
                        }
                      });
          ET_KERNEL_CHECK_MSG(
              ctx, success, Internal, , "parallel_for failed");
        });
  });
  return out;
}

} // namespace native
} // namespace custom
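
A quick reference for the outer/reduce/inner decomposition used by sum_strided and sum_scalar above (illustrative sketch only, not part of the commit; plain C++ with no ExecuTorch dependencies and a hypothetical shape). Reducing a contiguous [2, 3, 4] tensor over dim 1 gives outer_size = 2, reduce_size = 3, inner_size = 4; element (o, k, j) then sits at offset o * reduce_size * inner_size + k * inner_size + j, which is exactly the in_o[k * reduce_stride + j] indexing in the kernel:

#include <cstdio>

int main() {
  // Contiguous [2, 3, 4] input filled with 0..23, reduced over dim 1.
  const int outer_size = 2, reduce_size = 3, inner_size = 4;
  float in[24];
  for (int i = 0; i < 24; ++i) {
    in[i] = static_cast<float>(i);
  }
  float out[8]; // shape [outer_size, inner_size]
  const int reduce_stride = inner_size;
  const int outer_stride = reduce_size * inner_size;
  for (int o = 0; o < outer_size; ++o) {
    const float* in_o = in + o * outer_stride;
    for (int j = 0; j < inner_size; ++j) {
      float sum = 0.0f;
      for (int k = 0; k < reduce_size; ++k) {
        sum += in_o[k * reduce_stride + j];
      }
      out[o * inner_size + j] = sum;
    }
  }
  std::printf("out[0] = %g (0 + 4 + 8), out[7] = %g (15 + 19 + 23)\n", out[0], out[7]);
  return 0;
}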

kernels/optimized/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})
 add_library(
   custom_kernels STATIC
   ${EXECUTORCH_ROOT}/kernels/custom/neon_grid_sampler_2d.cpp
+  ${EXECUTORCH_ROOT}/kernels/custom/neon_sum.cpp
 )
 target_link_libraries(custom_kernels PUBLIC executorch_core)
 target_compile_options(custom_kernels PUBLIC ${_common_compile_options})

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
@@ -102,6 +102,11 @@
     - arg_meta: null
       kernel_name: torch::executor::opt_sub_out
 
+- op: sum.IntList_out
+  kernels:
+    - arg_meta: null
+      kernel_name: custom::sum_dim_out
+
 - op: sub.Scalar_out
   kernels:
     - arg_meta: null
