Commit b2bdc94

optimized: fallback to portable grid_sampler_2d for non-default layouts
The NEON fast path indexes input/grid/out directly, assuming a contiguous NCHW default-dim-order layout, with no use of .strides() or .dim_order(). If the caller passes anything else (NHWC, transposed, strided, channels-last), we'd read the wrong memory and silently produce garbage output.

Add the same check pattern op_sum.cpp already uses at L150-151: tensor_is_default_dim_order + tensor_is_contiguous on input, grid, and out. If any check fails, delegate to the portable kernel (which handles arbitrary strides / dim orders correctly via .strides()).

No perf impact on the hot path: the checks are a handful of scalar comparisons run once per call, and the common polycam depth model case is already default-contiguous, so the fast path is still taken.
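To make the failure mode in the message concrete, here is a minimal standalone sketch of why direct NCHW indexing goes wrong on a channels-last tensor. `View4D` and every helper below are hypothetical names invented for illustration; they are not ExecuTorch types, and this is not how `tensor_is_contiguous` is implemented.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical 4-D tensor description (illustration only, not an ExecuTorch type).
struct View4D {
  std::array<int64_t, 4> sizes;    // logical N, C, H, W
  std::array<int64_t, 4> strides;  // element strides for N, C, H, W
};

// Densely packed NCHW: W is the innermost (fastest-varying) dimension.
inline View4D make_nchw(int64_t n, int64_t c, int64_t h, int64_t w) {
  View4D v;
  v.sizes = {n, c, h, w};
  v.strides = {c * h * w, h * w, w, 1};
  return v;
}

// Same logical shape, but stored channels-last (NHWC): C is innermost.
inline View4D make_channels_last(int64_t n, int64_t c, int64_t h, int64_t w) {
  View4D v;
  v.sizes = {n, c, h, w};
  v.strides = {h * w * c, 1, w * c, c};
  return v;
}

// True only when the strides describe a densely packed NCHW buffer.
// Dimensions of size 1 are skipped because their stride never matters.
inline bool is_contiguous_nchw(const View4D& v) {
  int64_t expected = 1;
  for (int d = 3; d >= 0; --d) {
    if (v.sizes[d] != 1 && v.strides[d] != expected) {
      return false;
    }
    expected *= v.sizes[d];
  }
  return true;
}

// Offset a fast path computes when it assumes contiguous NCHW and ignores strides.
inline int64_t offset_assuming_nchw(
    const View4D& v, int64_t n, int64_t c, int64_t h, int64_t w) {
  assert(n < v.sizes[0] && c < v.sizes[1] && h < v.sizes[2] && w < v.sizes[3]);
  return ((n * v.sizes[1] + c) * v.sizes[2] + h) * v.sizes[3] + w;
}

// Offset that is correct for any layout because it honours the strides.
inline int64_t offset_with_strides(
    const View4D& v, int64_t n, int64_t c, int64_t h, int64_t w) {
  return n * v.strides[0] + c * v.strides[1] + h * v.strides[2] +
      w * v.strides[3];
}
```

For a 1x3x2x2 tensor, element (n=0, c=1, h=0, w=1) sits at offset 5 in an NCHW buffer but at offset 4 in a channels-last buffer. A kernel that always uses the first formula reads a different element without raising any error.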
1 parent 8721bfa commit b2bdc94

1 file changed

Lines changed: 13 additions & 2 deletions

kernels/optimized/cpu/op_grid_sampler_2d.cpp

@@ -294,9 +294,20 @@ Tensor& opt_grid_sampler_2d_out(
     int64_t padding_mode,
     bool align_corners,
     Tensor& out) {
+  // The NEON path indexes input/grid/out directly assuming a contiguous NCHW
+  // default-dim-order layout — no use of .strides() or .dim_order(). If the
+  // caller passes anything else, fall back to portable (which does handle
+  // arbitrary strides and dim orders correctly). These are cheap checks.
+  const bool fast_eligible = tensor_is_default_dim_order(input) &&
+      tensor_is_default_dim_order(grid) &&
+      tensor_is_default_dim_order(out) &&
+      tensor_is_contiguous(input) &&
+      tensor_is_contiguous(grid) &&
+      tensor_is_contiguous(out);
+
   // Only the bilinear + zeros-padding combination is accelerated. Everything
-  // else — and any non-aarch64 target — delegates to the portable kernel.
-  if (interpolation_mode != 0 || padding_mode != 0) {
+  // else — non-default layout, any non-aarch64 target — delegates to portable.
+  if (interpolation_mode != 0 || padding_mode != 0 || !fast_eligible) {
     return grid_sampler_2d_out(
         ctx, input, grid, interpolation_mode, padding_mode, align_corners, out);
   }
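The guard-then-delegate shape of this change can be sketched in isolation. The example below is a simplified 2-D analogue, not the actual kernel: `View2D`, `copy_fast`, `copy_portable`, and `copy_out` are hypothetical names, and a plain copy stands in for the NEON bilinear sampler. The dispatcher returns which path it took, purely so the behaviour is easy to observe.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 2-D strided view over float data (illustration only).
struct View2D {
  const float* data;
  int64_t rows;
  int64_t cols;
  int64_t row_stride;  // in elements
  int64_t col_stride;  // in elements
};

// Conservative contiguity check: a handful of scalar comparisons.
// It may reject some layouts that happen to be dense (for example a
// single-row view with an unusual row stride); those simply take the
// portable path, which is still correct.
inline bool is_contiguous(const View2D& v) {
  return v.col_stride == 1 && v.row_stride == v.cols;
}

// Portable path: honours strides, so it is correct for any layout.
inline void copy_portable(const View2D& in, float* out) {
  for (int64_t r = 0; r < in.rows; ++r) {
    for (int64_t c = 0; c < in.cols; ++c) {
      out[r * in.cols + c] = in.data[r * in.row_stride + c * in.col_stride];
    }
  }
}

// Fast path: a flat loop that is only valid for contiguous input.
inline void copy_fast(const View2D& in, float* out) {
  assert(is_contiguous(in));
  const int64_t n = in.rows * in.cols;
  for (int64_t i = 0; i < n; ++i) {
    out[i] = in.data[i];
  }
}

// Dispatcher: check eligibility once per call, then pick a path.
// Returns true when the fast path ran, false when it fell back.
inline bool copy_out(const View2D& in, float* out) {
  if (!is_contiguous(in)) {
    copy_portable(in, out);
    return false;
  }
  copy_fast(in, out);
  return true;
}
```

With a 2x3 row-major buffer the dispatcher takes the fast path. With a transposed 3x2 view over the same buffer (row stride 1, column stride 3) it falls back and still produces the logically correct row-major output.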
