Commit b7a8c24

melkap01-Arm, hariharans29, and Copilot authored
[CPU] GridSample operator performance improvement on bilinear interpolation… (microsoft#27359)
### Description

This change optimises the GridSample operator in ONNX Runtime (onnxrt).

1. A fast path is added for GridSample nodes with characteristics similar to those in the camera-based 3D object detection model in the MLPerf Automotive space, which transform input to output coordinates with 2D interpolation mode = linear, padding mode = zeros and align_corners = 0.

   Linear interpolation: for each (x, y), the code locates the four surrounding integer pixel centers:

   - (x1, y1) = (floor(x), floor(y)) (top-left)
   - (x2, y2) = (x1 + 1, y1 + 1) (bottom-right)

   The interpolation weights reflect the fractional positions:

   - dx1 = x - x1, dx2 = x2 - x
   - dy1 = y - y1, dy2 = y2 - y

   The resulting value is the bilinear blend `dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22)`, where p11…p22 are the input pixels at those four neighbor coordinates.

   Padding mode = zeros: any neighbor index that falls outside [0, W_in-1] × [0, H_in-1] contributes 0 to the interpolation.

   Each output pixel (oy, ox) carries normalized coordinates (nx, ny) in [-1, 1]. With align_corners = 0, nx = -1 corresponds to a location half a pixel before the leftmost input column (i.e., x = -0.5), and nx = 1 corresponds to half a pixel beyond the rightmost column (x = W_in - 0.5). The same applies vertically for ny.

   Fast path optimisation: the implementation precomputes all neighbor indices and weights for each output pixel once (they depend only on the grid), then reuses them for every channel. Previously, indices and weights were recalculated inside the loops for every channel, which can amount to H_out * W_out points (about 20,000 per batch element in one case) × 32 channels.

2. Optional ARM NEON vectorization is added, using the following intrinsics:

   - vld1_f32(ptr): loads two contiguous float values into a float32x2_t. Used to read the top and bottom neighbor pairs ([p11, p12], [p21, p22]).
   - vcombine_f32(low, high): concatenates two float32x2_t values into one float32x4_t, giving [p11, p12, p21, p22].
   - vdup_n_f32(val): duplicates a scalar float into both lanes of a float32x2_t.
   - vset_lane_f32(val, vec, lane): writes val into the specified lane of a float32x2_t, letting us form [w11, w12] and [w21, w22].
   - vmulq_f32(a, b): multiplies two float32x4_t vectors element-wise (neighbor pixels × weights).
   - vget_low_f32(vec) / vget_high_f32(vec): extract the lower or upper 2 lanes from a float32x4_t as float32x2_t.
   - vadd_f32(a, b): adds two float32x2_t vectors element-wise (forming partial sums).
   - vpadd_f32(a, b): performs pairwise adds within and across two float32x2_t vectors, collapsing four elements down to two.
   - vget_lane_f32(vec, lane): reads a scalar from a specific lane, giving the final interpolated value.

Most of the performance uplift comes from the first optimisation. The second optimisation, using NEON intrinsics, still contributes, but not as much as the first.

Overall performance improvement:

1 thread:
<img width="902" height="766" alt="image" src="https://github.com/user-attachments/assets/d1fadc6d-370d-4750-baee-1123c7d18af3" />

2 threads:
<img width="902" height="766" alt="image" src="https://github.com/user-attachments/assets/69c86fd6-815a-4b52-8f86-615f1c99bf0a" />

### Motivation and Context

The fast path handles denormalisation of the linear coordinates and derives the indices by precomputing a separate plan entry per output pixel. In PrecomputeBilinearSamplePlan2D, the loop runs across all H_out * W_out points, using the right nx/ny for each (oy, ox) and storing that point's four indices, four weights, and mask in plans[idx]. During evaluation, EvaluatePlanForChannel iterates through the same point_count (H_out * W_out) and uses the matching plan entry for each (oy, ox). So we are not reusing one plan across different spatial positions; we precompute one plan per output location and reuse it only across channels (which share the same grid).

---------

Signed-off-by: melkap01 <melike.kaptan@arm.com>
Co-authored-by: Hariharan Seshadri <shariharan91@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
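The precompute-once, reuse-per-channel idea described above can be sketched in a few lines of standalone C++. This is a simplified illustration, not the code from this commit: `SamplePlan`, `Denormalize`, `BuildPlans` and `SampleChannel` are hypothetical names, the element type is fixed to `float`, and threading and NEON are left out. The denormalisation formula is an assumption chosen to match the endpoints stated in the description (nx = -1 maps to -0.5, nx = 1 maps to W_in - 0.5).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch only: all names below are hypothetical, not ONNX Runtime APIs.
constexpr uint8_t kTopLeft = 1u << 0;
constexpr uint8_t kTopRight = 1u << 1;
constexpr uint8_t kBottomLeft = 1u << 2;
constexpr uint8_t kBottomRight = 1u << 3;

struct SamplePlan {
  int64_t x1 = 0, y1 = 0;                    // top-left neighbor; the others sit at +1
  float w11 = 0, w12 = 0, w21 = 0, w22 = 0;  // bilinear weights
  uint8_t mask = 0;                          // bit set => that neighbor is inside the image
};

// Assumed align_corners = 0 mapping: n = -1 -> -0.5 and n = 1 -> size - 0.5.
inline float Denormalize(float n, int64_t size) {
  return ((n + 1.0f) * static_cast<float>(size) - 1.0f) / 2.0f;
}

// Runs once per batch element: the result depends only on the grid and the input size.
std::vector<SamplePlan> BuildPlans(const std::vector<float>& grid, int64_t h_in, int64_t w_in) {
  assert(grid.size() % 2 == 0);  // interleaved (nx, ny) pairs
  std::vector<SamplePlan> plans(grid.size() / 2);
  for (std::size_t i = 0; i < plans.size(); ++i) {
    const float x = Denormalize(grid[2 * i], w_in);
    const float y = Denormalize(grid[2 * i + 1], h_in);
    const int64_t x1 = static_cast<int64_t>(std::floor(x));
    const int64_t y1 = static_cast<int64_t>(std::floor(y));
    const float dx1 = x - static_cast<float>(x1);
    const float dx2 = static_cast<float>(x1 + 1) - x;
    const float dy1 = y - static_cast<float>(y1);
    const float dy2 = static_cast<float>(y1 + 1) - y;
    const bool left = x1 >= 0 && x1 < w_in;
    const bool right = x1 + 1 >= 0 && x1 + 1 < w_in;
    const bool top = y1 >= 0 && y1 < h_in;
    const bool bottom = y1 + 1 >= 0 && y1 + 1 < h_in;

    SamplePlan& p = plans[i];
    p.x1 = x1;
    p.y1 = y1;
    p.w11 = dy2 * dx2;
    p.w12 = dy2 * dx1;
    p.w21 = dy1 * dx2;
    p.w22 = dy1 * dx1;
    if (left && top) p.mask |= kTopLeft;
    if (right && top) p.mask |= kTopRight;
    if (left && bottom) p.mask |= kBottomLeft;
    if (right && bottom) p.mask |= kBottomRight;
  }
  return plans;
}

// Runs once per channel and reuses the same plans; masked-out neighbors are never read.
std::vector<float> SampleChannel(const std::vector<float>& image, int64_t w_in,
                                 const std::vector<SamplePlan>& plans) {
  std::vector<float> out(plans.size(), 0.0f);
  for (std::size_t i = 0; i < plans.size(); ++i) {
    const SamplePlan& p = plans[i];
    auto pixel = [&](uint8_t bit, int64_t y, int64_t x) {
      return (p.mask & bit) ? image[static_cast<std::size_t>(y * w_in + x)] : 0.0f;
    };
    out[i] = p.w11 * pixel(kTopLeft, p.y1, p.x1) +
             p.w12 * pixel(kTopRight, p.y1, p.x1 + 1) +
             p.w21 * pixel(kBottomLeft, p.y1 + 1, p.x1) +
             p.w22 * pixel(kBottomRight, p.y1 + 1, p.x1 + 1);
  }
  return out;
}
```

Calling `BuildPlans` once on the grid `{0, 0, 0.9, 0, 0, 0.9, 0.9, 0.9}` for a 2×2 input and then `SampleChannel` on the channel `{1, 2, 3, 4}` should give approximately 2.5, 1.8, 2.1 and 1.44, the values expected by the right/bottom custom test in this commit. A second channel can be sampled with the same `plans` vector, which is where the saving comes from.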
1 parent 4fd9ebb commit b7a8c24

3 files changed

Lines changed: 298 additions & 52 deletions

onnxruntime/core/providers/cpu/tensor/grid_sample.cc

Lines changed: 221 additions & 52 deletions
@@ -1,8 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#include "core/providers/cpu/tensor/grid_sample.h"
+#include <vector>

+#include "core/providers/cpu/tensor/grid_sample.h"
 #include "core/framework/element_type_lists.h"
 #include "core/framework/TensorSeq.h"
 #include "core/providers/common.h"
@@ -148,6 +149,162 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
   return pixel;
 }

+namespace {
+
+constexpr uint8_t kTopLeftMask = 1u << 0;
+constexpr uint8_t kTopRightMask = 1u << 1;
+constexpr uint8_t kBottomLeftMask = 1u << 2;
+constexpr uint8_t kBottomRightMask = 1u << 3;
+constexpr uint8_t kAllNeighborsMask = kTopLeftMask | kTopRightMask | kBottomLeftMask | kBottomRightMask;
+
+template <typename T>
+struct BilinearSamplePlan2D {
+  int64_t x1;
+  int64_t x2;
+  int64_t y1;
+  int64_t y2;
+  T w11;
+  T w12;
+  T w21;
+  T w22;
+  uint8_t mask = 0;
+};
+// PrecomputeBilinearSamplePlan2D, the loop runs across all H_out * W_out points, using the right nx/ny for each (oy, ox) and
+// storing that point's four indices, four weights, and mask in plans[idx]. This operation takes place only when bilinear interpolation
+// is used with zero padding and no align_corners set, and it helps to speed up the sampling by precomputing the plan for each output pixel.
+template <typename T>
+void PrecomputeBilinearSamplePlan2D(const T* grid_data,
+                                    int64_t H_out,
+                                    int64_t W_out,
+                                    int64_t H_in,
+                                    int64_t W_in,
+                                    std::vector<BilinearSamplePlan2D<T>>& plans) {
+  const size_t point_count = static_cast<size_t>(H_out) * static_cast<size_t>(W_out);
+
+  for (size_t idx = 0; idx < point_count; ++idx) {
+    auto& plan = plans[idx];
+    const T nx = grid_data[idx * 2];
+    const T ny = grid_data[idx * 2 + 1];
+    const T x = GsDenormalize<T>(nx, W_in, false);
+    const T y = GsDenormalize<T>(ny, H_in, false);
+
+    const int64_t x1 = static_cast<int64_t>(std::floor(x));
+    const int64_t y1 = static_cast<int64_t>(std::floor(y));
+    const int64_t x2 = x1 + 1;
+    const int64_t y2 = y1 + 1;
+
+    const T dx2 = static_cast<T>(x2) - x;
+    const T dx1 = x - static_cast<T>(x1);
+    const T dy2 = static_cast<T>(y2) - y;
+    const T dy1 = y - static_cast<T>(y1);
+
+    uint8_t mask = 0;
+    if (x1 >= 0 && x1 < W_in && y1 >= 0 && y1 < H_in) {
+      mask |= kTopLeftMask;
+    }
+    if (x2 >= 0 && x2 < W_in && y1 >= 0 && y1 < H_in) {
+      mask |= kTopRightMask;
+    }
+    if (x1 >= 0 && x1 < W_in && y2 >= 0 && y2 < H_in) {
+      mask |= kBottomLeftMask;
+    }
+    if (x2 >= 0 && x2 < W_in && y2 >= 0 && y2 < H_in) {
+      mask |= kBottomRightMask;
+    }
+
+    plan.x1 = x1;
+    plan.x2 = x2;
+    plan.y1 = y1;
+    plan.y2 = y2;
+    plan.w11 = dy2 * dx2;
+    plan.w12 = dy2 * dx1;
+    plan.w21 = dy1 * dx2;
+    plan.w22 = dy1 * dx1;
+    plan.mask = mask;
+  }
+}
+
+template <typename T>
+void EvaluatePlanForChannel(const T* input_data,
+                            T* output_data,
+                            int64_t W_in,
+                            const BilinearSamplePlan2D<T>* plan_data,
+                            size_t point_count) {
+  for (size_t idx = 0; idx < point_count; ++idx) {
+    const auto& plan = plan_data[idx];
+    if (plan.mask == 0) {
+      output_data[idx] = T{};
+      continue;
+    }
+
+    T p11 = T{};
+    T p12 = T{};
+    T p21 = T{};
+    T p22 = T{};
+
+    if (plan.mask == kAllNeighborsMask) {
+      const int64_t row1 = plan.y1 * W_in;
+      const int64_t row2 = plan.y2 * W_in;
+      p11 = input_data[row1 + plan.x1];
+      p12 = input_data[row1 + plan.x2];
+      p21 = input_data[row2 + plan.x1];
+      p22 = input_data[row2 + plan.x2];
+    } else {
+      if (plan.mask & kTopLeftMask) {
+        p11 = input_data[plan.y1 * W_in + plan.x1];
+      }
+      if (plan.mask & kTopRightMask) {
+        p12 = input_data[plan.y1 * W_in + plan.x2];
+      }
+      if (plan.mask & kBottomLeftMask) {
+        p21 = input_data[plan.y2 * W_in + plan.x1];
+      }
+      if (plan.mask & kBottomRightMask) {
+        p22 = input_data[plan.y2 * W_in + plan.x2];
+      }
+    }
+
+    output_data[idx] = plan.w11 * p11 + plan.w12 * p12 + plan.w21 * p21 + plan.w22 * p22;
+  }
+}
+
+template <typename T>
+void TryRunBilinearZerosFastPath2D(const Tensor& input,
+                                   const Tensor& grid,
+                                   Tensor& output,
+                                   int64_t n,
+                                   int64_t C,
+                                   int64_t H_in,
+                                   int64_t W_in,
+                                   int64_t H_out,
+                                   int64_t W_out,
+                                   concurrency::ThreadPool* tp,
+                                   std::vector<BilinearSamplePlan2D<T>>& sampling_plan) {
+  const size_t plane_in = static_cast<size_t>(H_in) * static_cast<size_t>(W_in);
+  const size_t plane_out = static_cast<size_t>(H_out) * static_cast<size_t>(W_out);
+  sampling_plan.resize(plane_out);
+
+  const T* grid_data = grid.Data<T>() + n * plane_out * 2;
+  PrecomputeBilinearSamplePlan2D(grid_data, H_out, W_out, H_in, W_in, sampling_plan);
+
+  const T* input_data = input.Data<T>();
+  T* output_data = output.MutableData<T>();
+
+  if (plane_out == 0) {
+    return;
+  }
+
+  concurrency::ThreadPool::TrySimpleParallelFor(
+      tp, onnxruntime::narrow<std::ptrdiff_t>(C),
+      [&](std::ptrdiff_t c) {
+        const T* X_data = input_data + (n * C + c) * plane_in;
+        T* Y_data = output_data + (n * C + c) * plane_out;
+        EvaluatePlanForChannel(X_data, Y_data, W_in, sampling_plan.data(), plane_out);
+      });
+}
+
+} // namespace
+
 // When grid sampling, padding is applied before interpolation.
 // For instance, in bilinear mode and zeros padding-mode, pixel p at actual
 // image location (-0.5, -0.5)
@@ -210,61 +367,73 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
     T border[] = {x_min, y_min, x_max, y_max}; // l-t-r-b

     concurrency::ThreadPool* tp = H_out * W_out > 64 ? context->GetOperatorThreadPool() : nullptr;
-    for (int64_t n = 0; n < N; n++) {
-      const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
-      concurrency::ThreadPool::TrySimpleParallelFor(
-          tp, onnxruntime::narrow<std::ptrdiff_t>(C),
-          [&](std::ptrdiff_t c) {
-            const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
-            T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);
-
-            for (int64_t oy = 0; oy < H_out; oy++) {
-              for (int64_t ox = 0; ox < W_out; ox++) {
-                const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
-                T* Y_gridpoint = Y_data + oy * W_out + ox;
-                auto nx = gridpoint[0]; // normalized location
-                auto ny = gridpoint[1];
-                auto x = GsDenormalize<T>(nx, W_in, align_corners_); // actual location
-                auto y = GsDenormalize<T>(ny, H_in, align_corners_);
-
-                if (mode_ == Nearest) {
-                  x = static_cast<T>(std::nearbyint(static_cast<T>(x)));
-                  y = static_cast<T>(std::nearbyint(static_cast<T>(y)));
-                  // x, y are integers in all padding modes
-                  *Y_gridpoint = PixelAtGrid(X_data, static_cast<int64_t>(y), static_cast<int64_t>(x), H_in, W_in, border);
-                } else if (mode_ == Linear) {
-                  int64_t x1 = static_cast<int64_t>(std::floor(x));
-                  int64_t y1 = static_cast<int64_t>(std::floor(y));
-                  int64_t x2 = x1 + 1;
-                  int64_t y2 = y1 + 1;
-
-                  T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
-                  T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
-                  T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
-                  T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);
-
-                  T dx2 = static_cast<T>(x2) - x;
-                  T dx1 = x - static_cast<T>(x1);
-                  T dy2 = static_cast<T>(y2) - y;
-                  T dy1 = y - static_cast<T>(y1);
-                  *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
-                } else if (mode_ == Cubic) {
-                  int64_t x0 = static_cast<int64_t>(std::floor(x)) - 1; // top-left corner of the bbox
-                  int64_t y0 = static_cast<int64_t>(std::floor(y)) - 1;
-
-                  T p[4][4] = {}; // [H][W]
-                  for (int64_t h = 0; h < 4; h++) {
-                    for (int64_t w = 0; w < 4; w++) {
-                      p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
+
+    if (mode_ == Linear && padding_mode_ == Zeros && !align_corners_) {
+      std::vector<BilinearSamplePlan2D<T>> sampling_plan;
+      for (int64_t n = 0; n < N; n++) {
+        // Fast path for bilinear interpolation with zero padding when align_corners is false.
+        // Out-of-bounds neighbors are handled via masked loads and implicitly treated as zeros,
+        // and sampling_plan precomputes a separate plan entry per output pixel to avoid per-pixel
+        // boundary checks in the main loop.
+        TryRunBilinearZerosFastPath2D(*input, *grid, Y, n, C, H_in, W_in, H_out, W_out, tp, sampling_plan);
+      }
+    } else {
+      for (int64_t n = 0; n < N; n++) {
+        const T* grid_data = grid->Data<T>() + n * (H_out * W_out) * 2;
+        concurrency::ThreadPool::TrySimpleParallelFor(
+            tp, onnxruntime::narrow<std::ptrdiff_t>(C),
+            [&](std::ptrdiff_t c) {
+              const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
+              T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);
+
+              for (int64_t oy = 0; oy < H_out; oy++) {
+                for (int64_t ox = 0; ox < W_out; ox++) {
+                  const T* gridpoint = grid_data + (oy * W_out + ox) * 2;
+                  T* Y_gridpoint = Y_data + oy * W_out + ox;
+                  auto nx = gridpoint[0]; // normalized location
+                  auto ny = gridpoint[1];
+                  auto x = GsDenormalize<T>(nx, W_in, align_corners_); // actual location
+                  auto y = GsDenormalize<T>(ny, H_in, align_corners_);
+
+                  if (mode_ == Nearest) {
+                    x = static_cast<T>(std::nearbyint(static_cast<T>(x)));
+                    y = static_cast<T>(std::nearbyint(static_cast<T>(y)));
+                    // x, y are integers in all padding modes
+                    *Y_gridpoint = PixelAtGrid(X_data, static_cast<int64_t>(y), static_cast<int64_t>(x), H_in, W_in, border);
+                  } else if (mode_ == Linear) {
+                    int64_t x1 = static_cast<int64_t>(std::floor(x));
+                    int64_t y1 = static_cast<int64_t>(std::floor(y));
+                    int64_t x2 = x1 + 1;
+                    int64_t y2 = y1 + 1;
+
+                    T p11 = PixelAtGrid(X_data, y1, x1, H_in, W_in, border);
+                    T p12 = PixelAtGrid(X_data, y1, x2, H_in, W_in, border);
+                    T p21 = PixelAtGrid(X_data, y2, x1, H_in, W_in, border);
+                    T p22 = PixelAtGrid(X_data, y2, x2, H_in, W_in, border);
+
+                    T dx2 = static_cast<T>(x2) - x;
+                    T dx1 = x - static_cast<T>(x1);
+                    T dy2 = static_cast<T>(y2) - y;
+                    T dy1 = y - static_cast<T>(y1);
+                    *Y_gridpoint = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
+                  } else if (mode_ == Cubic) {
+                    int64_t x0 = static_cast<int64_t>(std::floor(x)) - 1; // top-left corner of the bbox
+                    int64_t y0 = static_cast<int64_t>(std::floor(y)) - 1;
+
+                    T p[4][4] = {}; // [H][W]
+                    for (int64_t h = 0; h < 4; h++) {
+                      for (int64_t w = 0; w < 4; w++) {
+                        p[h][w] = PixelAtGrid(X_data, h + y0, w + x0, H_in, W_in, border);
+                      }
                     }
+                    T dx = static_cast<T>(x - x0 - 1);
+                    T dy = static_cast<T>(y - y0 - 1);
+                    *Y_gridpoint = GsBicubicInterpolate(p, dx, dy);
                   }
-                  T dx = static_cast<T>(x - x0 - 1);
-                  T dy = static_cast<T>(y - y0 - 1);
-                  *Y_gridpoint = GsBicubicInterpolate(p, dx, dy);
                 }
               }
-            }
-          });
+            });
+      }
     }
   } else if (data_dims == 3) {
     // sample 3d;

onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ void RunTests(T& test, std::vector<std::unique_ptr<IExecutionProvider>>&& execut
   execution_providers.clear();
 }

+// Custom tests not generated by grid_sample_test_gen.py.
+#include "test/providers/cpu/tensor/grid_sample_test_custom.inc"
+
 // DO NOT edit following tests. They are generated by:
 // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
 template <typename T>
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+
+// Custom tests are kept in a separate include to avoid regenerating the main file.
+
+template <typename T>
+class GridSampleCustomTest : public ::testing::Test {
+};
+
+using GridSampleCustomTestTypes = ::testing::Types<float, MLFloat16>;
+TYPED_TEST_SUITE(GridSampleCustomTest, GridSampleCustomTestTypes);
+
+TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds_right_bottom) {
+  // Crafts grid points that mix fully in-bounds sampling with cases where either the right, bottom,
+  // or both neighbors fall outside the source image so zero padding must be applied. This ensures
+  // the optimized bilinear fast path matches the generic implementation for boundary handling.
+  OpTester test("GridSample", 20);
+  std::string mode = "linear";
+  std::string padding_mode = "zeros";
+  int64_t align_corners = 0;
+  std::initializer_list<int64_t> X_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> X_data{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)};
+  std::initializer_list<int64_t> Grid_shape{1, 2, 2, 2};
+  // (nx, ny) pairs: center (in-bounds), right edge (x out), bottom edge (y out), corner (both out)
+  std::initializer_list<TypeParam> Grid_data{
+      TypeParam(0.0f), TypeParam(0.0f), // center (all neighbors in bounds)
+      TypeParam(0.9f), TypeParam(0.0f), // near right edge (right neighbors out of bounds)
+      TypeParam(0.0f), TypeParam(0.9f), // near bottom edge (bottom neighbors out)
+      TypeParam(0.9f), TypeParam(0.9f)}; // near bottom-right corner (both right and bottom neighbors out)
+  std::initializer_list<int64_t> Y_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> Y_data{
+      TypeParam(2.5f), // all neighbors in bounds
+      TypeParam(1.8f), // right neighbors partially out-of-bounds
+      TypeParam(2.1f), // bottom neighbors partially out-of-bounds
+      TypeParam(1.44f)}; // both right and bottom neighbors out-of-bounds
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
+  test.AddAttribute("mode", mode);
+  test.AddAttribute("padding_mode", padding_mode);
+  test.AddAttribute("align_corners", align_corners);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
+  RunTests(test, GetExecutionProviders(20));
+}
+
+TYPED_TEST(GridSampleCustomTest, test_grid_sample_20_4D_linear_zeros_mixed_bounds_left_top) {
+  // Similar to test_grid_sample_20_4D_linear_zeros_mixed_bounds_right_bottom but focuses on left/top boundary cases,
+  // where the left and/or top neighbors fall outside the source image and zero padding must be applied.
+  // This ensures the optimized bilinear fast path correctly handles left/top boundary conditions.
+  OpTester test("GridSample", 20);
+  std::string mode = "linear";
+  std::string padding_mode = "zeros";
+  int64_t align_corners = 0;
+  std::initializer_list<int64_t> X_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> X_data{TypeParam(1.0f), TypeParam(2.0f), TypeParam(3.0f), TypeParam(4.0f)};
+  std::initializer_list<int64_t> Grid_shape{1, 2, 2, 2};
+  // (nx, ny) pairs: center (in-bounds), left edge (x out), top edge (y out), corner (both out)
+  std::initializer_list<TypeParam> Grid_data{
+      TypeParam(0.0f), TypeParam(0.0f), // center (all neighbors in bounds)
+      TypeParam(-0.9f), TypeParam(0.0f), // near left edge (left neighbors out of bounds)
+      TypeParam(0.0f), TypeParam(-0.9f), // near top edge (top neighbors out of bounds)
+      TypeParam(-0.9f), TypeParam(-0.9f)}; // near top-left corner (both left and top neighbors out of bounds)
+  std::initializer_list<int64_t> Y_shape{1, 1, 2, 2};
+  std::initializer_list<TypeParam> Y_data{
+      TypeParam(2.5f), // all neighbors in bounds
+      TypeParam(1.2f), // left neighbors partially out-of-bounds
+      TypeParam(0.9f), // top neighbors partially out-of-bounds
+      TypeParam(0.36f)}; // both left and top neighbors out-of-bounds
+  test.AddInput<TypeParam>("X", X_shape, X_data);
+  test.AddInput<TypeParam>("Grid", Grid_shape, Grid_data);
+  test.AddAttribute("mode", mode);
+  test.AddAttribute("padding_mode", padding_mode);
+  test.AddAttribute("align_corners", align_corners);
+  test.AddOutput<TypeParam>("Y", Y_shape, Y_data);
+  RunTests(test, GetExecutionProviders(20));
+}
+
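The expected outputs in the custom tests can be checked by hand. As an illustration that is not part of the commit, the derivation below works through the second grid point of the left/top test, (nx, ny) = (-0.9, 0.0), on the 2×2 input [[1, 2], [3, 4]]. It assumes the align_corners = 0 denormalisation x = ((nx + 1) · W_in - 1) / 2, which matches the endpoints given in the commit description (nx = -1 gives -0.5, nx = 1 gives W_in - 0.5).

```latex
\begin{aligned}
x &= \frac{(n_x + 1)\,W_{in} - 1}{2} = \frac{(-0.9 + 1)\cdot 2 - 1}{2} = -0.4,
\qquad y = \frac{(0 + 1)\cdot 2 - 1}{2} = 0.5 \\
(x_1, x_2) &= (-1,\ 0), \qquad (y_1, y_2) = (0,\ 1) \\
dx_1 &= x - x_1 = 0.6, \quad dx_2 = x_2 - x = 0.4, \quad dy_1 = dy_2 = 0.5 \\
p_{11} &= p_{21} = 0 \ \text{(column } x_1 = -1 \text{ is out of bounds)}, \quad p_{12} = 1, \quad p_{22} = 3 \\
Y &= dy_2\,(dx_2\,p_{11} + dx_1\,p_{12}) + dy_1\,(dx_2\,p_{21} + dx_1\,p_{22}) \\
  &= 0.5\,(0.6 \cdot 1) + 0.5\,(0.6 \cdot 3) = 0.3 + 0.9 = 1.2
\end{aligned}
```

The result agrees with the `TypeParam(1.2f)` entry in `Y_data`. The other entries follow the same pattern with different neighbors masked out.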
