Skip to content

Commit 271a05b

Browse files
author
Baris Demir
committed
Arm backend: Add VGF grid_sampler integration test
Add a VGF export/integration test for grid_sampler using the existing VgfPipeline. The test is gated on model-converter availability and runs with quantization disabled and Vulkan runtime execution disabled, so it validates the converter/export handoff without depending on runtime shader dispatch. Also register grid_sampler as a recognized custom edge op for the Arm test-name validator. Signed-off-by: Baris Demir <baris.demir@arm.com> Change-Id: I495d4283b09de7dedebce5d2f24b240e79d37dfc
1 parent d1bc458 commit 271a05b

1 file changed

Lines changed: 62 additions & 0 deletions

File tree

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import VgfPipeline
12+
13+
input_t = Tuple[torch.Tensor, torch.Tensor]
14+
aten_op = "torch.ops.aten.grid_sampler.default"
15+
exir_op = "executorch_exir_dialects_edge__ops_aten_grid_sampler_2d_default"
16+
17+
test_data_suite = {
18+
"2d_bilinear_zeros": lambda: (
19+
torch.randn(1, 3, 8, 8),
20+
torch.randn(1, 4, 4, 2),
21+
),
22+
}
23+
24+
xfails = {
25+
"2d_bilinear_zeros": (
26+
"CI model_converter does not yet include Vulkan custom-shader "
27+
"tosa.custom legalization",
28+
RuntimeError,
29+
),
30+
}
31+
32+
33+
class GridSampler2d(torch.nn.Module):
34+
def __init__(self):
35+
super().__init__()
36+
self.interpolation_mode_ = 0
37+
self.padding_mode_ = 0
38+
self.align_corners_ = False
39+
40+
def forward(self, x, grid):
41+
return F.grid_sample(
42+
x,
43+
grid,
44+
mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
45+
padding_mode="zeros" if self.padding_mode_ == 0 else "border",
46+
align_corners=self.align_corners_,
47+
)
48+
49+
50+
@common.parametrize("test_data", test_data_suite, xfails=xfails, strict=False)
51+
@common.SkipIfNoModelConverter
52+
def test_grid_sampler_vgf_no_quant(test_data):
53+
test_data = test_data()
54+
pipeline = VgfPipeline[input_t](
55+
GridSampler2d(),
56+
test_data,
57+
aten_op,
58+
exir_op,
59+
quantize=False,
60+
run_on_vulkan_runtime=False,
61+
)
62+
pipeline.run()

0 commit comments

Comments
 (0)