Commit 9ad1610

optimized: add NEON grid_sampler_2d.out and Vectorized<float> sum.IntList_out
Two new optimized CPU kernels registered alongside the existing optimized_kernels library. Both replace the portable reference kernel (still available as a fallback for unsupported inputs) with a vectorized implementation that accumulates in fp32, avoiding the fp16 precision issues noted in pytorch#19117 for grid_sampler_2d bilinear.

Measured end-to-end on a real depth model (Pixel 9, fp16 inputs, shapes representative of the model's hot path):

| Op                               | Portable | This PR | Speedup |
| -------------------------------- | -------- | ------- | ------- |
| grid_sampler_2d.out              | 17.3 ms  | 3.4 ms  | 5.1x    |
| sum.IntList_out (5 calls, total) | 3.0 ms   | 0.56 ms | 5.4x    |

### grid_sampler_2d.out

aarch64 NEON, bilinear + zeros padding only. Processes 4 channels per iteration with a vectorized FMA chain. fp16 inputs are promoted to fp32 for weight computation and accumulation, then cast back on store — the portable kernel's fp16 weight subtractions like `(ix_se - ix)` otherwise suffer catastrophic cancellation. Unsupported modes and non-aarch64 targets delegate to the portable kernel.

### sum.IntList_out

at::vec::Vectorized<float>-based implementation of the single-dim reduction fast path (both innermost-contiguous and strided cases). Cross-architecture SIMD via PyTorch's existing vector abstraction; accumulates in fp32 regardless of input dtype. Multi-dim reductions, dtype-converting reductions, and complex types delegate to portable.

### Integration

- Sources added to OPTIMIZED_KERNELS_SRCS in build_variables.bzl and to OPTIMIZED_ATEN_OPS in op_registration_util.bzl. Single source of truth for both Buck and CMake builds.
- optimized.yaml registers the ops with the standard opt_* naming convention used by sibling kernels.
- kernels/optimized/CMakeLists.txt scopes the -march=armv8.2-a+fp16 flag to just op_grid_sampler_2d.cpp via set_source_files_properties, so x86_64 builds are unaffected. The kernel has #ifdef __aarch64__ guards and falls through to portable on non-arm64 targets.
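Not part of the diff — a small standalone illustration of the precision point above: why the bilinear weights are computed in fp32 even when the tensors are fp16. The coordinate value and image size are arbitrary, and the snippet assumes a toolchain with `__fp16` support (e.g. clang targeting aarch64, which the kernel below already requires).

```cpp
#include <cmath>
#include <cstdio>

int main() {
  // Unnormalized sample x-coordinate, as produced by the grid unnormalization
  // for a wide input (illustrative value).
  const float gx = 400.7f;

  // fp32 path (what the optimized kernel does): the fractional weight keeps
  // roughly seven significant digits.
  const float fx_f32 = gx - std::floor(gx); // ~0.7

  // Naive fp16 path: around 400 the fp16 spacing is 0.25, so the coordinate
  // rounds to 400.75 before any subtraction. Subtracting two nearby values
  // then leaves only that rounding error, and the weight collapses to a
  // multiple of 0.25.
  const __fp16 gx_f16 = static_cast<__fp16>(gx); // stored as 400.75
  const float fx_f16 = static_cast<float>(gx_f16) -
      std::floor(static_cast<float>(gx_f16)); // 0.75, not ~0.7

  std::printf("fp32 weight: %f\n", fx_f32);
  std::printf("fp16 weight: %f\n", fx_f16);
  return 0;
}
```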
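The sum.IntList_out source is one of the six changed files but is not reproduced below, so here is a minimal sketch of the innermost-contiguous fast path described above, written against `at::vec::Vectorized<float>` as named in the summary. The function name, the raw-pointer interface, and the ATen include path are illustrative assumptions; the real kernel additionally handles fp16 inputs (widened to fp32 lanes) and the strided case, which this sketch omits.

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Sum the innermost (contiguous) dimension of an [outer, reduce_len] fp32
// buffer into out[outer], accumulating in fp32 vector lanes.
void sum_innermost_f32(
    const float* in,
    float* out,
    int64_t outer,
    int64_t reduce_len) {
  using Vec = at::vec::Vectorized<float>;
  constexpr int64_t kVecLen = Vec::size();

  for (int64_t o = 0; o < outer; ++o) {
    const float* row = in + o * reduce_len;

    // Vector partial sums over the bulk of the row.
    Vec acc(0.0f);
    int64_t i = 0;
    for (; i + kVecLen <= reduce_len; i += kVecLen) {
      acc = acc + Vec::loadu(row + i);
    }

    // Horizontal add of the accumulator lanes, then the scalar tail.
    float lanes[kVecLen];
    acc.store(lanes);
    float sum = 0.0f;
    for (int64_t l = 0; l < kVecLen; ++l) {
      sum += lanes[l];
    }
    for (; i < reduce_len; ++i) {
      sum += row[i];
    }
    out[o] = sum;
  }
}
```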
1 parent 0f471a6 commit 9ad1610

6 files changed

Lines changed: 587 additions & 0 deletions

kernels/optimized/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
@@ -75,6 +75,21 @@ target_link_libraries(
   kernels_util_all_deps
 )
 target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})
+
+# op_grid_sampler_2d.cpp uses ARMv8.2-a+fp16 NEON intrinsics
+# (vcvt_f32_f16 / vld1_f16) when compiled for aarch64. Scope the extra
+# `-march` flag to just that source so non-arm64 targets (e.g. x86_64 on
+# Android) are unaffected — the kernel itself has `#ifdef __aarch64__`
+# guards and falls through to the portable kernel otherwise.
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64"
+   OR ANDROID_ABI STREQUAL "arm64-v8a"
+)
+  set_source_files_properties(
+    ${EXECUTORCH_ROOT}/kernels/optimized/cpu/op_grid_sampler_2d.cpp
+    PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16"
+  )
+endif()
+
 # Build a library for _optimized_kernels_srcs
 #
 # optimized_ops_lib: Register optimized ops kernels into Executorch runtime
kernels/optimized/cpu/op_grid_sampler_2d.cpp

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Optimized grid_sampler_2d.out for CPU. On aarch64 this is a NEON-vectorized
// implementation for the common (bilinear + zeros padding) case, processing
// 4 channels at a time. Other modes — and non-aarch64 targets — fall through
// to the portable kernel.
//
// fp16 inputs: all interior math (interpolation weights and corner
// accumulation) is done in fp32. Loads/stores stay in the tensor's dtype.
// Avoids catastrophic cancellation on `ix_se - ix`-style subtractions that
// would otherwise make fp16 weights meaningless.

#include <executorch/runtime/kernel/kernel_includes.h>

#ifdef __aarch64__
#include <arm_neon.h>
#endif

#include <cmath>

namespace torch {
namespace executor {
namespace native {

using executorch::aten::ScalarType;
using executorch::aten::Tensor;

// Portable kernel (same-op fallback). Both libs link into the same binary.
Tensor& grid_sampler_2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out);

#ifdef __aarch64__
namespace {

// One output spatial location, all channels. fp32 path.
inline void bilinear_all_channels_f32(
    const float* input_n,
    float* output_n,
    int C,
    int H_in,
    int W_in,
    int H_out,
    int W_out,
    int h_out,
    int w_out,
    float gx,
    float gy) {
  const int x0 = static_cast<int>(std::floor(gx));
  const int y0 = static_cast<int>(std::floor(gy));
  const int x1 = x0 + 1;
  const int y1 = y0 + 1;
  const float fx = gx - static_cast<float>(x0);
  const float fy = gy - static_cast<float>(y0);

  const bool tl_v = static_cast<unsigned>(x0) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y0) < static_cast<unsigned>(H_in);
  const bool tr_v = static_cast<unsigned>(x1) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y0) < static_cast<unsigned>(H_in);
  const bool bl_v = static_cast<unsigned>(x0) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y1) < static_cast<unsigned>(H_in);
  const bool br_v = static_cast<unsigned>(x1) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y1) < static_cast<unsigned>(H_in);

  const int off_tl = y0 * W_in + x0;
  const int off_tr = y0 * W_in + x1;
  const int off_bl = y1 * W_in + x0;
  const int off_br = y1 * W_in + x1;
  const int spatial_in = H_in * W_in;
  const int spatial_out = H_out * W_out;
  const int out_off = h_out * W_out + w_out;

  const float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy));
  const float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy));
  const float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy);
  const float32x4_t vw_br = vdupq_n_f32(fx * fy);

  int c = 0;
  for (; c + 3 < C; c += 4) {
    const float* p0 = input_n + (c + 0) * spatial_in;
    const float* p1 = input_n + (c + 1) * spatial_in;
    const float* p2 = input_n + (c + 2) * spatial_in;
    const float* p3 = input_n + (c + 3) * spatial_in;

    float tl[4] = {0}, tr[4] = {0}, bl[4] = {0}, br[4] = {0};
    if (tl_v) {
      tl[0] = p0[off_tl]; tl[1] = p1[off_tl];
      tl[2] = p2[off_tl]; tl[3] = p3[off_tl];
    }
    if (tr_v) {
      tr[0] = p0[off_tr]; tr[1] = p1[off_tr];
      tr[2] = p2[off_tr]; tr[3] = p3[off_tr];
    }
    if (bl_v) {
      bl[0] = p0[off_bl]; bl[1] = p1[off_bl];
      bl[2] = p2[off_bl]; bl[3] = p3[off_bl];
    }
    if (br_v) {
      br[0] = p0[off_br]; br[1] = p1[off_br];
      br[2] = p2[off_br]; br[3] = p3[off_br];
    }

    float32x4_t result = vmulq_f32(vw_tl, vld1q_f32(tl));
    result = vfmaq_f32(result, vw_tr, vld1q_f32(tr));
    result = vfmaq_f32(result, vw_bl, vld1q_f32(bl));
    result = vfmaq_f32(result, vw_br, vld1q_f32(br));

    float res[4];
    vst1q_f32(res, result);
    output_n[(c + 0) * spatial_out + out_off] = res[0];
    output_n[(c + 1) * spatial_out + out_off] = res[1];
    output_n[(c + 2) * spatial_out + out_off] = res[2];
    output_n[(c + 3) * spatial_out + out_off] = res[3];
  }

  // Scalar tail
  const float w_tl = (1.0f - fx) * (1.0f - fy);
  const float w_tr = fx * (1.0f - fy);
  const float w_bl = (1.0f - fx) * fy;
  const float w_br = fx * fy;
  for (; c < C; ++c) {
    const float* p = input_n + c * spatial_in;
    float v = 0.0f;
    if (tl_v) v += w_tl * p[off_tl];
    if (tr_v) v += w_tr * p[off_tr];
    if (bl_v) v += w_bl * p[off_bl];
    if (br_v) v += w_br * p[off_br];
    output_n[c * spatial_out + out_off] = v;
  }
}

// fp16 path: loads/stores fp16, math in fp32.
inline void bilinear_all_channels_f16(
    const __fp16* input_n,
    __fp16* output_n,
    int C,
    int H_in,
    int W_in,
    int H_out,
    int W_out,
    int h_out,
    int w_out,
    float gx,
    float gy) {
  const int x0 = static_cast<int>(std::floor(gx));
  const int y0 = static_cast<int>(std::floor(gy));
  const int x1 = x0 + 1;
  const int y1 = y0 + 1;
  const float fx = gx - static_cast<float>(x0);
  const float fy = gy - static_cast<float>(y0);

  const bool tl_v = static_cast<unsigned>(x0) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y0) < static_cast<unsigned>(H_in);
  const bool tr_v = static_cast<unsigned>(x1) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y0) < static_cast<unsigned>(H_in);
  const bool bl_v = static_cast<unsigned>(x0) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y1) < static_cast<unsigned>(H_in);
  const bool br_v = static_cast<unsigned>(x1) < static_cast<unsigned>(W_in) &&
      static_cast<unsigned>(y1) < static_cast<unsigned>(H_in);

  const int off_tl = y0 * W_in + x0;
  const int off_tr = y0 * W_in + x1;
  const int off_bl = y1 * W_in + x0;
  const int off_br = y1 * W_in + x1;
  const int spatial_in = H_in * W_in;
  const int spatial_out = H_out * W_out;
  const int out_off = h_out * W_out + w_out;

  const float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy));
  const float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy));
  const float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy);
  const float32x4_t vw_br = vdupq_n_f32(fx * fy);

  int c = 0;
  for (; c + 3 < C; c += 4) {
    const __fp16* p0 = input_n + (c + 0) * spatial_in;
    const __fp16* p1 = input_n + (c + 1) * spatial_in;
    const __fp16* p2 = input_n + (c + 2) * spatial_in;
    const __fp16* p3 = input_n + (c + 3) * spatial_in;

    __fp16 tl[4] = {0}, tr[4] = {0}, bl[4] = {0}, br[4] = {0};
    if (tl_v) {
      tl[0] = p0[off_tl]; tl[1] = p1[off_tl];
      tl[2] = p2[off_tl]; tl[3] = p3[off_tl];
    }
    if (tr_v) {
      tr[0] = p0[off_tr]; tr[1] = p1[off_tr];
      tr[2] = p2[off_tr]; tr[3] = p3[off_tr];
    }
    if (bl_v) {
      bl[0] = p0[off_bl]; bl[1] = p1[off_bl];
      bl[2] = p2[off_bl]; bl[3] = p3[off_bl];
    }
    if (br_v) {
      br[0] = p0[off_br]; br[1] = p1[off_br];
      br[2] = p2[off_br]; br[3] = p3[off_br];
    }

    const float32x4_t v_tl = vcvt_f32_f16(vld1_f16(tl));
    const float32x4_t v_tr = vcvt_f32_f16(vld1_f16(tr));
    const float32x4_t v_bl = vcvt_f32_f16(vld1_f16(bl));
    const float32x4_t v_br = vcvt_f32_f16(vld1_f16(br));

    float32x4_t result = vmulq_f32(vw_tl, v_tl);
    result = vfmaq_f32(result, vw_tr, v_tr);
    result = vfmaq_f32(result, vw_bl, v_bl);
    result = vfmaq_f32(result, vw_br, v_br);

    __fp16 res[4];
    vst1_f16(res, vcvt_f16_f32(result));
    output_n[(c + 0) * spatial_out + out_off] = res[0];
    output_n[(c + 1) * spatial_out + out_off] = res[1];
    output_n[(c + 2) * spatial_out + out_off] = res[2];
    output_n[(c + 3) * spatial_out + out_off] = res[3];
  }

  const float w_tl = (1.0f - fx) * (1.0f - fy);
  const float w_tr = fx * (1.0f - fy);
  const float w_bl = (1.0f - fx) * fy;
  const float w_br = fx * fy;
  for (; c < C; ++c) {
    const __fp16* p = input_n + c * spatial_in;
    float v = 0.0f;
    if (tl_v) v += w_tl * static_cast<float>(p[off_tl]);
    if (tr_v) v += w_tr * static_cast<float>(p[off_tr]);
    if (bl_v) v += w_bl * static_cast<float>(p[off_bl]);
    if (br_v) v += w_br * static_cast<float>(p[off_br]);
    output_n[c * spatial_out + out_off] = static_cast<__fp16>(v);
  }
}

template <typename SCALAR, typename SampleFn>
void grid_sampler_2d_neon(
    const SCALAR* input,
    const SCALAR* grid,
    SCALAR* output,
    int N,
    int C,
    int H_in,
    int W_in,
    int H_out,
    int W_out,
    bool align_corners,
    SampleFn sample_fn) {
  const int spatial_in = H_in * W_in;
  const int spatial_out = H_out * W_out;

  for (int n = 0; n < N; ++n) {
    const SCALAR* input_n = input + n * C * spatial_in;
    SCALAR* output_n = output + n * C * spatial_out;
    const SCALAR* grid_n = grid + n * H_out * W_out * 2;

    for (int h = 0; h < H_out; ++h) {
      if (h + 1 < H_out) {
        __builtin_prefetch(grid_n + (h + 1) * W_out * 2, 0, 1);
      }
      for (int w = 0; w < W_out; ++w) {
        const int grid_off = (h * W_out + w) * 2;
        float gx = static_cast<float>(grid_n[grid_off]);
        float gy = static_cast<float>(grid_n[grid_off + 1]);
        if (align_corners) {
          gx = (gx + 1.0f) * (W_in - 1) * 0.5f;
          gy = (gy + 1.0f) * (H_in - 1) * 0.5f;
        } else {
          gx = (gx + 1.0f) * W_in * 0.5f - 0.5f;
          gy = (gy + 1.0f) * H_in * 0.5f - 0.5f;
        }
        sample_fn(
            input_n, output_n, C, H_in, W_in, H_out, W_out, h, w, gx, gy);
      }
    }
  }
}

} // namespace
#endif // __aarch64__

Tensor& opt_grid_sampler_2d_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const Tensor& grid,
    int64_t interpolation_mode,
    int64_t padding_mode,
    bool align_corners,
    Tensor& out) {
  // Only the bilinear + zeros-padding combination is accelerated. Everything
  // else — and any non-aarch64 target — delegates to the portable kernel.
  if (interpolation_mode != 0 || padding_mode != 0) {
    return grid_sampler_2d_out(
        ctx, input, grid, interpolation_mode, padding_mode, align_corners, out);
  }
#ifndef __aarch64__
  return grid_sampler_2d_out(
      ctx, input, grid, interpolation_mode, padding_mode, align_corners, out);
#else
  const int N = static_cast<int>(input.size(0));
  const int C = static_cast<int>(input.size(1));
  const int H_in = static_cast<int>(input.size(2));
  const int W_in = static_cast<int>(input.size(3));
  const int H_out = static_cast<int>(grid.size(1));
  const int W_out = static_cast<int>(grid.size(2));

  if (input.scalar_type() == ScalarType::Float) {
    grid_sampler_2d_neon<float>(
        input.const_data_ptr<float>(),
        grid.const_data_ptr<float>(),
        out.mutable_data_ptr<float>(),
        N, C, H_in, W_in, H_out, W_out,
        align_corners,
        bilinear_all_channels_f32);
    return out;
  }
  if (input.scalar_type() == ScalarType::Half) {
    static_assert(sizeof(__fp16) == 2, "expected __fp16 == 2 bytes");
    grid_sampler_2d_neon<__fp16>(
        reinterpret_cast<const __fp16*>(input.const_data_ptr<uint16_t>()),
        reinterpret_cast<const __fp16*>(grid.const_data_ptr<uint16_t>()),
        reinterpret_cast<__fp16*>(out.mutable_data_ptr<uint16_t>()),
        N, C, H_in, W_in, H_out, W_out,
        align_corners,
        bilinear_all_channels_f16);
    return out;
  }
  // Any other dtype (e.g. Double, BFloat16): let portable handle it.
  return grid_sampler_2d_out(
      ctx, input, grid, interpolation_mode, padding_mode, align_corners, out);
#endif
}

} // namespace native
} // namespace executor
} // namespace torch
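Also not part of the diff — a sketch of how the new entry point could be exercised directly in a host-side aarch64 check. The op signature is copied from the source above; `TensorFactory`, its include path, and the chosen shapes and values are assumptions.

```cpp
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <cstdio>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using torch::executor::KernelRuntimeContext;
using torch::executor::testing::TensorFactory;

namespace torch {
namespace executor {
namespace native {
// Signature as added by this commit.
Tensor& opt_grid_sampler_2d_out(
    KernelRuntimeContext& ctx, const Tensor& input, const Tensor& grid,
    int64_t interpolation_mode, int64_t padding_mode, bool align_corners,
    Tensor& out);
} // namespace native
} // namespace executor
} // namespace torch

int main() {
  TensorFactory<ScalarType::Float> tf;
  // N=1, C=4 (one full NEON lane group), 2x2 input, 1x1 output grid.
  Tensor input = tf.make(
      {1, 4, 2, 2},
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  // Sample at the exact center of the image: with align_corners=false this
  // lands halfway between all four pixels, so each channel's result should be
  // the mean of its four values (2.5, 6.5, 10.5, 14.5).
  Tensor grid = tf.make({1, 1, 1, 2}, {0.0f, 0.0f});
  Tensor out = tf.zeros({1, 4, 1, 1});

  KernelRuntimeContext ctx;
  torch::executor::native::opt_grid_sampler_2d_out(
      ctx, input, grid, /*interpolation_mode=*/0, /*padding_mode=*/0,
      /*align_corners=*/false, out);

  std::printf("channel 0: %f\n", out.const_data_ptr<float>()[0]); // ~2.5
  return 0;
}
```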
