Commit ed5a100

jgibson2 and claude committed
optimized/grid_sampler_2d: address review feedback
Three changes consolidated for review:

1. Move the forward declaration of grid_sampler_2d_bilinear_fp16_hw out of op_grid_sampler_2d.cpp into a new header, kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.h. The function has external linkage (the dispatcher in op_grid_sampler_2d.cpp calls into it across translation units), and previously its definition site had no prior prototype visible, which trips -Wmissing-prototypes on build configurations that enable it. Both .cpp files now include the shared header. The function body stays in op_grid_sampler_2d_fp16_hw.cpp because that TU is the only one compiled with -march=armv8.2-a+fp16, so it cannot be inlined into a header. The header itself uses void* for the input/output buffers and is fp16-free, so callers don't need the +fp16 march flag just to declare or call it.

2. Split the fp16 HW path into its own CMake target. Previously the -march=armv8.2-a+fp16 flag was scoped per source file via set_source_files_properties on the sole TU inside the optimized_kernels library. That works for a clean non-LTO build, but with ThinLTO or other cross-TU optimizations the flag boundary becomes fuzzy, and the fallback path in op_grid_sampler_2d.cpp could in principle be auto-vectorized into fp16 NEON instructions, which is exactly the SIGILL hazard the runtime dispatch is meant to prevent. Build the file as an OBJECT library (grid_sampler_2d_fp16_hw_impl) with a target-scoped -march flag and link it into optimized_kernels via $<BUILD_LOCAL_INTERFACE:...>, so the object code is baked into liboptimized_kernels.a at archive time and the OBJECT target is kept out of the install EXPORT set. Mirrors the existing buck `grid_sampler_2d_fp16_hw_impl` cxx_library.

3. Gate the optimized fast paths on input/grid/out dtype match. Each fast path assumes a single dtype across all three tensors:
   - fp32 NEON path: data_ptr<float>() on all three
   - fp16 HW path: void* pointers reinterpret_cast to __fp16* on all three
   - fp16 SW NEON path: data_ptr<c10::Half>() on all three

   Until now the dispatcher gated only on input.scalar_type(). The reinterpret_casts in the fp16 HW kernel are particularly load-bearing because their behavior on a mismatched dtype would be silent corruption (reading int64/double bytes at an __fp16 stride). The data_ptr<T>() runtime check exists but is not guaranteed in release builds. Add a dtypes_match clause at the top of the fast-path eligibility check that requires all three scalar types to be equal; fall back to the portable kernel otherwise. The portable kernel handles arbitrary dtype combinations correctly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4e98f5a commit ed5a100

5 files changed

Lines changed: 83 additions & 28 deletions


kernels/optimized/CMakeLists.txt

Lines changed: 24 additions & 7 deletions
@@ -77,17 +77,34 @@ target_link_libraries(
 target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})

 # op_grid_sampler_2d_fp16_hw.cpp uses hardware fp16 NEON intrinsics
-# (vcvt_f32_f16 / vld1_f16). Those are part of the ARMv8.2-a+fp16 extension and
-# raise SIGILL on chips without it. Scope the `-march` flag to just that
-# translation unit. The main op_grid_sampler_2d.cpp (which hosts the runtime
-# dispatcher via cpuinfo_has_arm_neon_fp16) and the fp16 software-convert path
-# stay on plain ARMv8 so they can run on any chip.
+# (vcvt_f32_f16 / vld1_f16). Those are part of the ARMv8.2-a+fp16 extension
+# and raise SIGILL on chips without it. Build it as a separate OBJECT
+# library so the `-march=armv8.2-a+fp16` flag stays strictly scoped to that
+# translation unit and never reaches the dispatcher / fallback code in
+# op_grid_sampler_2d.cpp (which would otherwise risk auto-vectorizing into
+# fp16 NEON instructions). The dispatcher chooses between this entry point
+# and the fp16 software-convert path at runtime via
+# cpuinfo_has_arm_neon_fp16(). Mirrors the buck `grid_sampler_2d_fp16_hw_impl`
+# library.
 if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64" OR ANDROID_ABI STREQUAL
    "arm64-v8a"
 )
-  set_source_files_properties(
+  add_library(
+    grid_sampler_2d_fp16_hw_impl OBJECT
     ${EXECUTORCH_ROOT}/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp
-    PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16"
+  )
+  target_compile_options(
+    grid_sampler_2d_fp16_hw_impl
+    PRIVATE -march=armv8.2-a+fp16 ${_common_compile_options}
+  )
+  target_link_libraries(grid_sampler_2d_fp16_hw_impl PRIVATE executorch_core)
+  # BUILD_LOCAL_INTERFACE: object files are baked into optimized_kernels.a
+  # at archive time, so this OBJECT target stays out of the install EXPORT
+  # set and downstream consumers of the installed optimized_kernels need
+  # no separate link entry.
+  target_link_libraries(
+    optimized_kernels
+    PRIVATE $<BUILD_LOCAL_INTERFACE:grid_sampler_2d_fp16_hw_impl>
  )
 endif()

kernels/optimized/cpu/op_grid_sampler_2d.cpp

Lines changed: 13 additions & 20 deletions
@@ -33,6 +33,7 @@
 #ifdef __aarch64__
 #include <arm_neon.h>
 #include <cpuinfo.h>
+#include <executorch/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.h>
 #endif

 #include <c10/util/Half.h>
@@ -56,25 +57,6 @@ Tensor& grid_sampler_2d_out(
     bool align_corners,
     Tensor& out);

-#ifdef __aarch64__
-namespace opt_grid_sampler_2d_internal {
-// Declared in op_grid_sampler_2d_fp16_hw.cpp, compiled separately with
-// `-march=armv8.2-a+fp16`. Only safe to call when
-// cpuinfo_has_arm_neon_fp16() is true.
-void grid_sampler_2d_bilinear_fp16_hw(
-    const void* input,
-    const void* grid,
-    void* output,
-    int N,
-    int C,
-    int H_in,
-    int W_in,
-    int H_out,
-    int W_out,
-    bool align_corners);
-} // namespace opt_grid_sampler_2d_internal
-#endif
-
 #ifdef __aarch64__
 namespace {

@@ -361,7 +343,18 @@ Tensor& opt_grid_sampler_2d_out(
       tensor_is_contiguous(input) && tensor_is_contiguous(grid) &&
       tensor_is_contiguous(out);

-  if (interpolation_mode != 0 || padding_mode != 0 || !fast_eligible) {
+  // The fast paths read input/grid and write out as a single dtype: float for
+  // the fp32 NEON path, fp16 for both the fp16 HW path (which raw-casts the
+  // void* pointers to __fp16*) and the SW fp16 NEON path (which uses
+  // data_ptr<c10::Half>(), whose runtime dtype check is not guaranteed in
+  // release builds). Reject any mixed-dtype call up front so none of those
+  // unchecked casts can be reached with a mismatched buffer.
+  const bool dtypes_match =
+      input.scalar_type() == grid.scalar_type() &&
+      input.scalar_type() == out.scalar_type();
+
+  if (interpolation_mode != 0 || padding_mode != 0 || !fast_eligible ||
+      !dtypes_match) {
     return grid_sampler_2d_out(
         ctx, input, grid, interpolation_mode, padding_mode, align_corners, out);
   }
kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@

 #ifdef __aarch64__

+#include <executorch/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.h>
+
 #include <arm_neon.h>
 #include <cmath>

kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.h

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#ifdef __aarch64__
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace opt_grid_sampler_2d_internal {
+
+// Hardware-fp16 NEON bilinear + zeros-padding fast path. Defined in
+// op_grid_sampler_2d_fp16_hw.cpp, which is the only translation unit
+// compiled with `-march=armv8.2-a+fp16`. Only safe to call when
+// cpuinfo_has_arm_neon_fp16() reports true — see the runtime dispatcher
+// in op_grid_sampler_2d.cpp.
+//
+// Input/output buffers are passed as void* (raw uint16_t storage
+// interpreted as __fp16) so this header doesn't need <arm_neon.h> and
+// callers don't need the +fp16 march flag just to declare it.
+void grid_sampler_2d_bilinear_fp16_hw(
+    const void* input,
+    const void* grid,
+    void* output,
+    int N,
+    int C,
+    int H_in,
+    int W_in,
+    int H_out,
+    int W_out,
+    bool align_corners);
+
+} // namespace opt_grid_sampler_2d_internal
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+#endif // __aarch64__

shim_et/xplat/executorch/build/build_variables.bzl

Lines changed: 0 additions & 1 deletion
@@ -268,7 +268,6 @@ OPTIMIZED_KERNELS_SRCS = [
     "kernels/optimized/cpu/op_fft_r2c.cpp",
     "kernels/optimized/cpu/op_gelu.cpp",
     "kernels/optimized/cpu/op_grid_sampler_2d.cpp",
-    "kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp",
     "kernels/optimized/cpu/op_le.cpp",
     "kernels/optimized/cpu/op_linear.cpp",
     "kernels/optimized/cpu/op_log_softmax.cpp",
