Skip to content

Commit b5f8155

Browse files
authored
Arm backend: Lower grid_sampler_2d to VGF TOSA CUSTOM (#19547)
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani --------- Signed-off-by: Baris Demir <baris.demir@arm.com>
1 parent 6663aea commit b5f8155

16 files changed

Lines changed: 636 additions & 4 deletions

backends/arm/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,27 @@ runtime.python_library(
8787
name = "vgf",
8888
srcs = [
8989
"vgf/__init__.py",
90+
"vgf/_passes/__init__.py",
91+
"vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py",
9092
"vgf/backend.py",
9193
"vgf/compile_spec.py",
9294
"vgf/model_converter.py",
9395
"vgf/partitioner.py",
96+
"vgf/shaders/__init__.py",
97+
"vgf/shaders/grid_sampler.py",
98+
],
99+
resources = [
100+
"vgf/shaders/grid_sampler.glsl",
101+
"vgf/shaders/grid_sampler.spirv.b64",
94102
],
95103
deps = [
96104
":arm_compile_spec",
105+
"//caffe2:torch",
106+
"//executorch/backends/arm/_passes:passes",
107+
"//executorch/backends/arm/tosa/dialect:lib",
97108
"//executorch/backends/arm/tosa:specification",
98109
"//executorch/backends/arm/tosa:partitioner",
110+
"//executorch/exir:lib",
99111
],
100112
)
101113

backends/arm/ethosu/partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from typing import final, Optional, Sequence
77

8-
import torch
98
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
109
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
1110
from executorch.exir.backend.partitioner import DelegationSpec
11+
from torch._ops import OpOverload
1212
from torch.fx.passes.operator_support import OperatorSupportBase
1313

1414

@@ -33,5 +33,5 @@ def __init__(
3333
)
3434
self.additional_checks = additional_checks
3535
self.tosa_spec = compile_spec.tosa_spec
36-
self._custom_partition_ops: set[torch._ops.OpOverload] = set()
36+
self._custom_partition_ops: set[OpOverload] = set()
3737
self.intermediate_path = compile_spec._get_intermediate_path()
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import argparse
7+
import base64
8+
import shutil
9+
import subprocess # nosec B404 - required to invoke the shader compiler.
10+
import tempfile
11+
from pathlib import Path
12+
13+
14+
SHADER_DIR = Path(__file__).resolve().parents[1] / "vgf" / "shaders"
15+
DEFAULT_SOURCE = SHADER_DIR / "grid_sampler.glsl"
16+
DEFAULT_OUTPUT = SHADER_DIR / "grid_sampler.spirv.b64"
17+
18+
19+
def _parse_args() -> argparse.Namespace:
20+
parser = argparse.ArgumentParser(
21+
description=(
22+
"Compile the VGF grid_sampler GLSL shader to SPIR-V and write the "
23+
"base64-encoded payload consumed by the ExecuTorch custom-shader "
24+
"lowering."
25+
)
26+
)
27+
parser.add_argument(
28+
"--source",
29+
type=Path,
30+
default=DEFAULT_SOURCE,
31+
help=f"GLSL source file. Defaults to {DEFAULT_SOURCE}",
32+
)
33+
parser.add_argument(
34+
"--output",
35+
type=Path,
36+
default=DEFAULT_OUTPUT,
37+
help=f"Base64 SPIR-V output file. Defaults to {DEFAULT_OUTPUT}",
38+
)
39+
parser.add_argument(
40+
"--glslc",
41+
default="glslc",
42+
help="Path to glslc. Defaults to resolving glslc from PATH.",
43+
)
44+
return parser.parse_args()
45+
46+
47+
def _resolve_glslc(glslc: str) -> str:
48+
resolved = shutil.which(glslc)
49+
if resolved is None:
50+
raise RuntimeError(
51+
f"Could not find {glslc}. Install the Vulkan SDK or pass --glslc."
52+
)
53+
return resolved
54+
55+
56+
def _write_base64_spirv(spirv_path: Path, output_path: Path) -> None:
57+
encoded = base64.b64encode(spirv_path.read_bytes()).decode("ascii")
58+
output_path.write_text(encoded + "\n", encoding="utf-8")
59+
60+
61+
def main() -> None:
62+
args = _parse_args()
63+
glslc = _resolve_glslc(args.glslc)
64+
65+
with tempfile.TemporaryDirectory() as tmpdir:
66+
spirv_path = Path(tmpdir) / "grid_sampler.spirv"
67+
subprocess.run( # nosec B603 - glslc path is resolved explicitly.
68+
[glslc, str(args.source), "-o", str(spirv_path)],
69+
check=True,
70+
)
71+
_write_base64_spirv(spirv_path, args.output)
72+
73+
74+
if __name__ == "__main__":
75+
main()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import base64
7+
8+
import pytest
9+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
10+
build_grid_sampler_2d_payload,
11+
decode_payload,
12+
encode_payload,
13+
GRID_SAMPLER_2D_SHADER_BINARY,
14+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
15+
GRID_SAMPLER_2D_SHADER_LANGUAGE,
16+
GRID_SAMPLER_2D_SHADER_SOURCE,
17+
GRID_SAMPLER_2D_VK_FORMAT,
18+
GRID_SAMPLER_2D_WORKGROUP_SIZES,
19+
)
20+
21+
22+
def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip():
23+
payload = build_grid_sampler_2d_payload(
24+
interpolation_mode=0,
25+
padding_mode=2,
26+
align_corners=True,
27+
)
28+
decoded = decode_payload(encode_payload(payload))
29+
30+
assert decoded["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
31+
assert decoded["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
32+
assert decoded["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
33+
assert base64.b64decode(decoded["shader_code"])[:4] == b"\x03\x02\x23\x07"
34+
assert decoded["input_0_type"] == "Tensor"
35+
assert decoded["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
36+
assert decoded["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
37+
assert decoded["input_0_binding"] == 0
38+
assert decoded["input_1_type"] == "Tensor"
39+
assert decoded["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
40+
assert decoded["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
41+
assert decoded["input_1_binding"] == 1
42+
assert decoded["output_0_type"] == "Tensor"
43+
assert decoded["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
44+
assert decoded["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
45+
assert decoded["output_0_binding"] == 2
46+
47+
48+
def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv():
49+
payload = build_grid_sampler_2d_payload(
50+
interpolation_mode=0,
51+
padding_mode=0,
52+
align_corners=False,
53+
)
54+
55+
shader_binary = base64.b64decode(payload["shader_code"])
56+
57+
assert payload["shader_language"] == "SPIR-V"
58+
assert shader_binary[:4] == b"\x03\x02\x23\x07"
59+
60+
61+
def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources():
62+
assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl"
63+
assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64"
64+
65+
66+
def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes():
67+
with pytest.raises(ValueError, match="Unsupported interpolation_mode"):
68+
build_grid_sampler_2d_payload(
69+
interpolation_mode=99,
70+
padding_mode=0,
71+
align_corners=False,
72+
)
73+
74+
with pytest.raises(ValueError, match="Unsupported padding_mode"):
75+
build_grid_sampler_2d_payload(
76+
interpolation_mode=0,
77+
padding_mode=99,
78+
align_corners=False,
79+
)

backends/arm/test/misc/test_extract_io_params_tosa.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99
import torch
10+
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
1011
from executorch.backends.arm.quantizer import VgfQuantizer
1112
from executorch.backends.arm.quantizer.arm_quantizer import (
1213
get_symmetric_quantization_config,
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
1920
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
2021
from executorch.exir import to_edge_transform_and_lower
22+
from executorch.exir.dialects._ops import ops as exir_ops
2123
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
2224
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2325

@@ -88,3 +90,26 @@ def test_roundtrip_extracts_io_params_tosa_INT(
8890
assert isinstance(out_name, str)
8991
assert isinstance(out_params["scale"], float)
9092
assert isinstance(out_params["zero_point"], int)
93+
94+
95+
def test_only_vgf_partitioner_registers_grid_sampler_no_target_custom_partition_op():
96+
tosa_partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
97+
vgf_partitioner = VgfPartitioner(VgfCompileSpec("TOSA-1.0+FP"))
98+
ethosu_partitioner = EthosUPartitioner(EthosUCompileSpec("ethos-u55-128"))
99+
100+
assert hasattr(tosa_partitioner, "_custom_partition_ops")
101+
assert hasattr(vgf_partitioner, "_custom_partition_ops")
102+
assert hasattr(ethosu_partitioner, "_custom_partition_ops")
103+
104+
assert (
105+
exir_ops.edge.aten.grid_sampler_2d.default
106+
not in tosa_partitioner._custom_partition_ops
107+
)
108+
assert (
109+
exir_ops.edge.aten.grid_sampler_2d.default
110+
in vgf_partitioner._custom_partition_ops
111+
)
112+
assert (
113+
exir_ops.edge.aten.grid_sampler_2d.default
114+
not in ethosu_partitioner._custom_partition_ops
115+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
import torch.nn.functional as F
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import VgfPipeline
12+
13+
input_t = Tuple[torch.Tensor, torch.Tensor]
14+
aten_op = "torch.ops.aten.grid_sampler.default"
15+
exir_op = "executorch_exir_dialects_edge__ops_aten_grid_sampler_2d_default"
16+
17+
test_data_suite = {
18+
"2d_bilinear_zeros": lambda: (
19+
torch.randn(1, 3, 8, 8),
20+
torch.randn(1, 4, 4, 2),
21+
),
22+
}
23+
24+
xfails = {
25+
"2d_bilinear_zeros": (
26+
"CI model_converter does not yet include Vulkan custom-shader "
27+
"tosa.custom legalization",
28+
RuntimeError,
29+
),
30+
}
31+
32+
33+
class GridSampler2d(torch.nn.Module):
34+
def __init__(self):
35+
super().__init__()
36+
self.interpolation_mode_ = 0
37+
self.padding_mode_ = 0
38+
self.align_corners_ = False
39+
40+
def forward(self, x, grid):
41+
return F.grid_sample(
42+
x,
43+
grid,
44+
mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
45+
padding_mode="zeros" if self.padding_mode_ == 0 else "border",
46+
align_corners=self.align_corners_,
47+
)
48+
49+
50+
@common.parametrize("test_data", test_data_suite, xfails=xfails, strict=False)
51+
@common.SkipIfNoModelConverter
52+
def test_grid_sampler_vgf_no_quant(test_data):
53+
test_data = test_data()
54+
pipeline = VgfPipeline[input_t](
55+
GridSampler2d(),
56+
test_data,
57+
aten_op,
58+
exir_op,
59+
quantize=False,
60+
run_on_vulkan_runtime=False,
61+
)
62+
pipeline.run()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import torch
8+
import torch.nn.functional as F
9+
from executorch.backends.arm.tosa.specification import (
10+
TosaLoweringContext,
11+
TosaSpecification,
12+
)
13+
from executorch.backends.arm.vgf._passes.rewrite_grid_sampler_to_tosa_custom import (
14+
RewriteGridSamplerToTosaCustomPass,
15+
)
16+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
17+
CUSTOM_SHADER_DOMAIN_NAME,
18+
decode_payload,
19+
GRID_SAMPLER_2D_OPERATOR_NAME,
20+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
21+
GRID_SAMPLER_2D_SHADER_LANGUAGE,
22+
GRID_SAMPLER_2D_VK_FORMAT,
23+
GRID_SAMPLER_2D_WORKGROUP_SIZES,
24+
)
25+
from executorch.exir import to_edge
26+
from executorch.exir.dialects._ops import ops as exir_ops
27+
from torch.export import export
28+
29+
30+
class GridSampler2d(torch.nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.interpolation_mode_ = 0
34+
self.padding_mode_ = 0
35+
self.align_corners_ = False
36+
37+
def forward(self, x, grid):
38+
return F.grid_sample(
39+
x,
40+
grid,
41+
mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
42+
padding_mode="zeros" if self.padding_mode_ == 0 else "border",
43+
align_corners=self.align_corners_,
44+
)
45+
46+
47+
def test_rewrite_grid_sampler_to_tosa_custom_no_target():
48+
model = GridSampler2d()
49+
example_inputs = (
50+
torch.randn(1, 3, 8, 8),
51+
torch.randn(1, 4, 4, 2),
52+
)
53+
54+
edge_model = to_edge(export(model, example_inputs))
55+
nodes = list(edge_model.exported_program().graph.nodes)
56+
57+
assert any(
58+
node.target == exir_ops.edge.aten.grid_sampler_2d.default for node in nodes
59+
)
60+
61+
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")):
62+
edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()])
63+
nodes = list(edge_model.exported_program().graph.nodes)
64+
65+
assert not any(
66+
node.target == exir_ops.edge.aten.grid_sampler_2d.default for node in nodes
67+
)
68+
69+
custom_node = next(
70+
node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default
71+
)
72+
assert custom_node.kwargs["operator_name"] == GRID_SAMPLER_2D_OPERATOR_NAME
73+
assert custom_node.kwargs["domain_name"] == CUSTOM_SHADER_DOMAIN_NAME
74+
75+
payload = decode_payload(custom_node.kwargs["implementation_attrs"])
76+
assert payload["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
77+
assert payload["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
78+
assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
79+
assert payload["input_0_type"] == "Tensor"
80+
assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
81+
assert payload["input_0_binding"] == 0
82+
assert payload["input_0_descriptorset"] == 0
83+
assert payload["input_1_type"] == "Tensor"
84+
assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
85+
assert payload["input_1_binding"] == 1
86+
assert payload["input_1_descriptorset"] == 0
87+
assert payload["output_0_type"] == "Tensor"
88+
assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
89+
assert payload["output_0_binding"] == 2
90+
assert payload["output_0_descriptorset"] == 0

0 commit comments

Comments
 (0)