File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -338,10 +338,10 @@ Tensor& opt_grid_sampler_2d_out(
338338 // The NEON paths index input/grid/out directly assuming a contiguous NCHW
339339 // default-dim-order layout — no use of .strides() or .dim_order(). Fall
340340 // back to portable for anything else.
341- const bool fast_eligible = tensor_is_default_dim_order ( input) &&
342- tensor_is_default_dim_order (grid ) && tensor_is_default_dim_order (out ) &&
343- tensor_is_contiguous (input ) && tensor_is_contiguous (grid ) &&
344- tensor_is_contiguous (out);
341+ const bool fast_eligible = input. dim () == 4 && grid. dim () == 4 &&
342+ tensor_is_default_dim_order (input ) && tensor_is_default_dim_order (grid ) &&
343+ tensor_is_default_dim_order (out ) && tensor_is_contiguous (input ) &&
344+ tensor_is_contiguous (grid) && tensor_is_contiguous ( out);
345345
346346 // The fast paths read input/grid and write out as a single dtype: float for
347347 // the fp32 NEON path, fp16 for both the fp16 HW path (which raw-casts the
You can’t perform that action at this time.
0 commit comments