Skip to content

Commit 851cffb

Browse files
Fix missing check (#19340)
Missing dimension check which was breaking test.
1 parent 3c4ec8f commit 851cffb

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

kernels/optimized/cpu/op_grid_sampler_2d.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,10 @@ Tensor& opt_grid_sampler_2d_out(
338338
// The NEON paths index input/grid/out directly assuming a contiguous NCHW
339339
// default-dim-order layout — no use of .strides() or .dim_order(). Fall
340340
// back to portable for anything else.
341-
const bool fast_eligible = tensor_is_default_dim_order(input) &&
342-
tensor_is_default_dim_order(grid) && tensor_is_default_dim_order(out) &&
343-
tensor_is_contiguous(input) && tensor_is_contiguous(grid) &&
344-
tensor_is_contiguous(out);
341+
const bool fast_eligible = input.dim() == 4 && grid.dim() == 4 &&
342+
tensor_is_default_dim_order(input) && tensor_is_default_dim_order(grid) &&
343+
tensor_is_default_dim_order(out) && tensor_is_contiguous(input) &&
344+
tensor_is_contiguous(grid) && tensor_is_contiguous(out);
345345

346346
// The fast paths read input/grid and write out as a single dtype: float for
347347
// the fp32 NEON path, fp16 for both the fp16 HW path (which raw-casts the

0 commit comments

Comments
 (0)