Commit b384173
optimized: add grid_sampler_2d.out (NEON) and sum.IntList_out (Vectorized<float>) (#19119)
## Summary
Two new optimized CPU kernels registered in the existing `optimized_kernels` library. Both replace the portable reference kernels (still available as fallbacks for unsupported inputs) with vectorized implementations that accumulate in fp32, which also sidesteps the fp16 precision issue noted in #19117 for `grid_sampler_2d` bilinear.
Measured end-to-end on a real depth model (Pixel 9 / arm64-v8a, fp16
inputs, shapes representative of the model's hot path):
| Op | Portable | This PR | Speedup |
|---|---:|---:|---:|
| `grid_sampler_2d.out` | 17.3 ms | **3.4 ms** | **5.1×** |
| `sum.IntList_out` (5 calls, aggregate) | 3.0 ms | **0.56 ms** | **5.4×** |
## `grid_sampler_2d.out`
aarch64 NEON implementation, covering bilinear interpolation with zeros padding only (the dominant mode for depth / MVS / spatial transformer networks). It processes 4 channels per iteration with a vectorized FMA chain. fp16 inputs are promoted to fp32 for weight computation and accumulation, then cast back to fp16 on store; the portable kernel's fp16 weight subtractions such as `(ix_se - ix)` otherwise suffer catastrophic cancellation (the same concern as #19117). Unsupported modes and non-aarch64 targets delegate to the portable kernel.
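To make the promotion concrete, here is a minimal sketch of the pattern described above. All helper names are hypothetical (not the PR's actual code), and the real kernel's loop structure is necessarily more involved:

```cpp
// Hedged sketch: fp32 bilinear weights + NEON FMA accumulation.
#ifdef __aarch64__
#include <arm_neon.h>

// Bilinear corner weights computed in fp32 even for fp16 tensors, so
// subtractions like (ix_se - ix) keep their fractional bits instead of
// cancelling in half precision.
static inline void bilinear_weights_f32(
    float ix, float iy,        // unnormalized sample position
    int ix_nw, int iy_nw,      // north-west integer corner
    float w[4]) {              // out: nw, ne, sw, se
  const float ix_se = ix_nw + 1.0f;
  const float iy_se = iy_nw + 1.0f;
  w[0] = (ix_se - ix) * (iy_se - iy);  // nw
  w[1] = (ix - ix_nw) * (iy_se - iy);  // ne
  w[2] = (ix_se - ix) * (iy - iy_nw);  // sw
  w[3] = (ix - ix_nw) * (iy - iy_nw);  // se
}

// Load 4 fp16 channel values and promote them to fp32 lanes.
static inline float32x4_t load4_f16_as_f32(const __fp16* p) {
  return vcvt_f32_f16(vld1_f16(p));
}

// One link of the FMA chain: acc += corner_vals * w, 4 channels at once.
static inline float32x4_t fma_corner(
    float32x4_t acc, float32x4_t corner_vals, float w) {
  return vfmaq_n_f32(acc, corner_vals, w);
}
#endif  // __aarch64__
```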
## `sum.IntList_out`
An `at::vec::Vectorized<float>`-based implementation of the single-dim reduction fast path, covering both innermost-contiguous and strided cases. It gets cross-architecture SIMD via PyTorch's existing vector abstraction and always accumulates in fp32 regardless of input dtype. Multi-dim reductions, dtype-converting reductions, and complex types delegate to portable.
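As a rough sketch of the contiguous fast path (assuming PyTorch's ATen vec headers; the function name and loop structure are illustrative, and the real kernel also handles strided layouts and fp16 inputs):

```cpp
// Hedged sketch of a contiguous fp32 sum via Vectorized<float>.
#include <ATen/cpu/vec/vec.h>
#include <cstdint>

float sum_contiguous_f32(const float* data, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  constexpr int64_t kWidth = Vec::size();  // e.g. 4 on NEON, 8 on AVX2
  Vec acc(0.0f);
  int64_t i = 0;
  for (; i + kWidth <= n; i += kWidth) {
    acc = acc + Vec::loadu(data + i);  // vectorized partial sums
  }
  // Horizontal reduction of the vector lanes.
  float lanes[kWidth];
  acc.store(lanes);
  float total = 0.0f;
  for (int64_t j = 0; j < kWidth; ++j) {
    total += lanes[j];
  }
  for (; i < n; ++i) {  // scalar tail
    total += data[i];
  }
  return total;
}
```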
## Integration
- Sources added to `OPTIMIZED_KERNELS_SRCS` in `build_variables.bzl` and
to `OPTIMIZED_ATEN_OPS` in `op_registration_util.bzl`. Single source of
truth for both Buck and CMake builds.
- `optimized.yaml` registers the ops with the standard `opt_*` naming
convention used by sibling kernels.
- `kernels/optimized/CMakeLists.txt` scopes the `-march=armv8.2-a+fp16` flag to just `op_grid_sampler_2d.cpp` via `set_source_files_properties` (see the sketch after this list), so x86_64 builds are unaffected. The kernel has `#ifdef __aarch64__` guards and falls through to portable on non-arm64 targets.
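A hedged sketch of that per-file flag scoping (the exact source path and property spelling in the PR may differ):

```cmake
# Scope the armv8.2+fp16 flag to the one translation unit that needs it,
# leaving every other source (and x86_64 builds) untouched.
set_source_files_properties(
  cpu/op_grid_sampler_2d.cpp
  PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16"
)
```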
## Test plan
- [x] Builds cleanly for Android arm64-v8a, Android x86_64 (via
`scripts/build_android_library.sh`), and host (macOS / Apple Clang 21).
- [x] Existing `kernels/test/op_grid_sampler_2d_test.cpp` and
`op_sum_test.cpp` unit tests continue to pass — both target the
`aten::sum_outf` / `aten::grid_sampler_2d_outf` codegen-dispatched entry
points, so they automatically exercise the optimized kernels when
linked.
- [x] Numerical verification against an fp32 reference (run portable in fp32, cast to fp16) on the shapes the polycam depth model uses; all cases pass within fp16 ULP (see the comparison sketch after this list).
- [x] End-to-end Pixel 9 latency on a representative trained model
matches the handwritten-NEON reference implementation to within
run-to-run noise while producing more accurate fp16 outputs (fp32
accumulation).
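For reference, one way to implement the "within fp16 ULP" comparison from the third checklist item; this is a sketch, not the PR's actual test harness:

```cpp
#include <cstdint>
#include <cstdlib>

// Map an IEEE fp16 bit pattern to a monotonically increasing integer so
// that adjacent representable values differ by exactly 1 (+0 and -0 map
// to the same value; NaN handling is omitted for brevity).
static int32_t f16_to_ordered(uint16_t bits) {
  return (bits & 0x8000) ? (0x8000 - (bits & 0x7FFF)) : (0x8000 + bits);
}

// True when two fp16 values (as raw bit patterns) are within max_ulps
// units in the last place of each other.
static bool f16_close_ulp(uint16_t a, uint16_t b, int32_t max_ulps = 1) {
  return std::abs(f16_to_ordered(a) - f16_to_ordered(b)) <= max_ulps;
}
```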
Candidate successor to #19117 for the grid_sampler half: it applies the same precision fix at the optimized-kernel layer, so callers who link `optimized_ops_lib` get both the correctness fix and the speedup.
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>