Commit 49f2ea9

[PICKED] Arm backend: Lower grid_sampler_2d to VGF TOSA CUSTOM pytorch#19547
Signed-off-by: Rob Elliott <Robert.Elliott@arm.com>
Change-Id: Ic30ddbbdbf0a08d724d2dde9d2f6432918091932
Signed-off-by: Rob Elliott <Robert.Elliott@arm.com>
1 parent 47b71d8 commit 49f2ea9

16 files changed

Lines changed: 609 additions & 9 deletions

backends/arm/ethosu/partitioner.py

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@

 from typing import final, Optional, Sequence

-import torch
 from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
 from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
 from executorch.exir.backend.partitioner import DelegationSpec
+from torch._ops import OpOverload
 from torch.fx.passes.operator_support import OperatorSupportBase


@@ -33,5 +33,5 @@ def __init__(
         )
         self.additional_checks = additional_checks
         self.tosa_spec = compile_spec.tosa_spec
-        self._custom_partition_ops: set[torch._ops.OpOverload] = set()
+        self._custom_partition_ops: set[OpOverload] = set()
         self.intermediate_path = compile_spec._get_intermediate_path()

backends/arm/scripts/docgen/ethos-u/backends-arm-ethos-u-overview.md.in

Lines changed: 5 additions & 5 deletions
@@ -4,7 +4,7 @@ The Arm&reg; Ethos&trade;-U backend targets Edge/IoT-type AI use-cases by enabli
 [Arm&reg; Ethos&trade;-U55 NPU](https://www.arm.com/products/silicon-ip-cpu/ethos/ethos-u55), [Arm&reg; Ethos&trade;-U65 NPU](https://www.arm.com/products/silicon-ip-cpu/ethos/ethos-u65), and
 [Arm&reg; Ethos&trade;-U85 NPU](https://www.arm.com/products/silicon-ip-cpu/ethos/ethos-u85), leveraging [TOSA](https://www.mlplatform.org/tosa/) and the
 [ethos-u-vela](https://pypi.org/project/ethos-u-vela/) graph compiler. This document is a technical reference for using the Ethos-U backend, for a top level view with code examples
-please refer to the [Arm Ethos-U Backend Tutorial](tutorials/ethos-u-getting-started.md). <!-- @lint-ignore -->
+please refer to the [Arm Ethos-U Backend Tutorial](tutorials/ethos-u-getting-started.md).

 ## Features

@@ -27,7 +27,7 @@ For the AOT flow, compilation of a model to `.pte` format using the Ethos-U back
 - [TOSA Serialization Library](https://www.mlplatform.org/tosa/software.html) for serializing the Exir IR graph into TOSA IR.
 - [Ethos-U Vela graph compiler](https://pypi.org/project/ethos-u-vela/) for compiling TOSA flatbuffers into an Ethos-U command stream.

-And for building and running the example application available in `examples/arm/executor_runner/` through the standalone CMake entry point:
+And for building and running the example application available in `examples/arm/executor_runner/`:
 - [Arm GNU Toolchain](https://developer.arm.com/Tools%20and%20Software/GNU%20Toolchain) for cross compilation.
 - [Arm&reg; Corstone&trade; SSE-300 FVP](https://developer.arm.com/documentation/100966/1128/Arm--Corstone-SSE-300-FVP) for testing on a Arm&reg; Cortex&reg;-M55+Ethos-U55 reference design.
 - [Arm&reg; Corstone&trade; SSE-320 FVP](https://developer.arm.com/documentation/109760/0000/SSE-320-FVP) for testing on a Arm&reg; Cortex&reg;-M85+Ethos-U85 reference design.
@@ -55,7 +55,7 @@ For more information on quantization, see [Quantization](arm-ethos-u-quantizatio

 ## Runtime Integration

-An example runtime application is available in [examples/arm/executor_runner](https://github.com/pytorch/executorch/blob/main/examples/arm/executor_runner/), with a standalone CMake entry point in `examples/arm/executor_runner/standalone`. The steps required for building and deploying it on an FVP are explained in the previously mentioned [Arm Ethos-U Backend Tutorial](tutorials/ethos-u-getting-started.md). <!-- @lint-ignore -->
+An example runtime application is available in [examples/arm/executor_runner](https://github.com/pytorch/executorch/blob/main/examples/arm/executor_runner/), and the steps requried for building and deploying it on a FVP it is explained in the previously mentioned [Arm Ethos-U Backend Tutorial](tutorials/ethos-u-getting-started.md).
 The example application is recommended to use for testing basic functionality of your lowered models, as well as a starting point for developing runtime integrations for your own targets.
 For an in-depth explanation of the architecture of the executor_runner and the steps required for doing such an integration, please refer to [Ethos-U porting guide](https://github.com/pytorch/executorch/blob/main/examples/arm/ethos-u-porting-guide.md).

@@ -153,7 +153,7 @@ ExecuTorch for the Ethos-U backend, you automatically install the compiler conta

 **→{doc}`/backends/arm-ethos-u/arm-ethos-u-troubleshooting` — Troubleshooting and common issues.**

-**→{doc}`/backends/arm-ethos-u/tutorials/ethos-u-getting-started` — Getting started tutorial.**
+**→{doc}`/backends/arm-ethos-u/tutorials/arm-ethos-u-tutorials` — Tutorials.**

 **→{doc}`/backends/arm-ethos-u/U55_op_support` — Ethos-U55 supported operators.**

@@ -168,7 +168,7 @@ ExecuTorch for the Ethos-U backend, you automatically install the compiler conta
 arm-ethos-u-partitioner
 arm-ethos-u-quantization
 arm-ethos-u-troubleshooting
-tutorials/ethos-u-getting-started
+tutorials/arm-ethos-u-tutorials
 U55_op_support
 U85_op_support
 ```
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import base64
+import shutil
+import subprocess  # nosec B404 - required to invoke the shader compiler.
+import tempfile
+from pathlib import Path
+
+
+SHADER_DIR = Path(__file__).resolve().parents[1] / "vgf" / "shaders"
+DEFAULT_SOURCE = SHADER_DIR / "grid_sampler.glsl"
+DEFAULT_OUTPUT = SHADER_DIR / "grid_sampler.spirv.b64"
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description=(
+            "Compile the VGF grid_sampler GLSL shader to SPIR-V and write the "
+            "base64-encoded payload consumed by the ExecuTorch custom-shader "
+            "lowering."
+        )
+    )
+    parser.add_argument(
+        "--source",
+        type=Path,
+        default=DEFAULT_SOURCE,
+        help=f"GLSL source file. Defaults to {DEFAULT_SOURCE}",
+    )
+    parser.add_argument(
+        "--output",
+        type=Path,
+        default=DEFAULT_OUTPUT,
+        help=f"Base64 SPIR-V output file. Defaults to {DEFAULT_OUTPUT}",
+    )
+    parser.add_argument(
+        "--glslc",
+        default="glslc",
+        help="Path to glslc. Defaults to resolving glslc from PATH.",
+    )
+    return parser.parse_args()
+
+
+def _resolve_glslc(glslc: str) -> str:
+    resolved = shutil.which(glslc)
+    if resolved is None:
+        raise RuntimeError(
+            f"Could not find {glslc}. Install the Vulkan SDK or pass --glslc."
+        )
+    return resolved
+
+
+def _write_base64_spirv(spirv_path: Path, output_path: Path) -> None:
+    encoded = base64.b64encode(spirv_path.read_bytes()).decode("ascii")
+    output_path.write_text(encoded + "\n", encoding="utf-8")
+
+
+def main() -> None:
+    args = _parse_args()
+    glslc = _resolve_glslc(args.glslc)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        spirv_path = Path(tmpdir) / "grid_sampler.spirv"
+        subprocess.run(  # nosec B603 - glslc path is resolved explicitly.
+            [glslc, str(args.source), "-o", str(spirv_path)],
+            check=True,
+        )
+        _write_base64_spirv(spirv_path, args.output)
+
+
+if __name__ == "__main__":
+    main()
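
Review note: the script above writes base64 text, and the tests added in this commit check that the decoded bytes start with `03 02 23 07`. That is the SPIR-V magic number `0x07230203` stored in little-endian byte order. A minimal sketch of the same sanity check follows. It uses a stand-in header instead of real `glslc` output, and `looks_like_spirv` is a hypothetical helper name that does not appear in the commit.

```python
import base64
import struct

SPIRV_MAGIC = 0x07230203  # first word of every SPIR-V module


def looks_like_spirv(encoded: str) -> bool:
    """Return True if the base64 text decodes to a little-endian SPIR-V module."""
    blob = base64.b64decode(encoded)
    if len(blob) < 4:
        return False
    (magic,) = struct.unpack("<I", blob[:4])
    return magic == SPIRV_MAGIC


# Stand-in for a compiled module: only the first header word is meaningful here.
fake_module = struct.pack("<I", SPIRV_MAGIC) + b"\x00" * 16
# Mirror the script's output format: ASCII base64 followed by a newline.
encoded = base64.b64encode(fake_module).decode("ascii") + "\n"
```

A check like this could run on the regenerated `.b64` file before committing it, which would catch an empty or truncated compiler output early.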
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import base64
+
+import pytest
+from executorch.backends.arm.vgf.shaders.grid_sampler import (
+    build_grid_sampler_2d_payload,
+    decode_payload,
+    encode_payload,
+    GRID_SAMPLER_2D_SHADER_BINARY,
+    GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
+    GRID_SAMPLER_2D_SHADER_LANGUAGE,
+    GRID_SAMPLER_2D_SHADER_SOURCE,
+    GRID_SAMPLER_2D_VK_FORMAT,
+    GRID_SAMPLER_2D_WORKGROUP_SIZES,
+)
+
+
+def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip():
+    payload = build_grid_sampler_2d_payload(
+        interpolation_mode=0,
+        padding_mode=2,
+        align_corners=True,
+    )
+    decoded = decode_payload(encode_payload(payload))
+
+    assert decoded["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
+    assert decoded["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
+    assert decoded["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
+    assert base64.b64decode(decoded["shader_code"])[:4] == b"\x03\x02\x23\x07"
+    assert decoded["input_0_type"] == "Tensor"
+    assert decoded["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
+    assert decoded["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
+    assert decoded["input_0_binding"] == 0
+    assert decoded["input_1_type"] == "Tensor"
+    assert decoded["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
+    assert decoded["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
+    assert decoded["input_1_binding"] == 1
+    assert decoded["output_0_type"] == "Tensor"
+    assert decoded["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
+    assert decoded["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
+    assert decoded["output_0_binding"] == 2
+
+
+def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv():
+    payload = build_grid_sampler_2d_payload(
+        interpolation_mode=0,
+        padding_mode=0,
+        align_corners=False,
+    )
+
+    shader_binary = base64.b64decode(payload["shader_code"])
+
+    assert payload["shader_language"] == "SPIR-V"
+    assert shader_binary[:4] == b"\x03\x02\x23\x07"
+
+
+def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources():
+    assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl"
+    assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64"
+
+
+def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes():
+    with pytest.raises(ValueError, match="Unsupported interpolation_mode"):
+        build_grid_sampler_2d_payload(
+            interpolation_mode=99,
+            padding_mode=0,
+            align_corners=False,
+        )
+
+    with pytest.raises(ValueError, match="Unsupported padding_mode"):
+        build_grid_sampler_2d_payload(
+            interpolation_mode=0,
+            padding_mode=99,
+            align_corners=False,
+        )
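
Review note: the integer modes in these tests appear to follow the ATen `grid_sampler` convention, where interpolation 0/1/2 means bilinear/nearest/bicubic and padding 0/1/2 means zeros/border/reflection. The sketch below shows one way a builder could reject unknown values with the messages the last test matches. Everything in it is an assumption for illustration: `check_modes`, `Interpolation`, and `Padding` are hypothetical names, this is not the code in `grid_sampler.py`, and the subset of modes the real lowering accepts is not visible in this diff.

```python
from enum import IntEnum


class Interpolation(IntEnum):
    BILINEAR = 0
    NEAREST = 1
    BICUBIC = 2


class Padding(IntEnum):
    ZEROS = 0
    BORDER = 1
    REFLECTION = 2


def check_modes(
    interpolation_mode: int, padding_mode: int
) -> tuple[Interpolation, Padding]:
    """Map raw integer modes to enums, raising ValueError for unknown values."""
    try:
        interp = Interpolation(interpolation_mode)
    except ValueError:
        raise ValueError(
            f"Unsupported interpolation_mode: {interpolation_mode}"
        ) from None
    try:
        pad = Padding(padding_mode)
    except ValueError:
        raise ValueError(f"Unsupported padding_mode: {padding_mode}") from None
    return interp, pad
```

Validating both values up front keeps the error text stable, which is what lets `pytest.raises(..., match=...)` pin the message.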

backends/arm/test/misc/test_extract_io_params_tosa.py

Lines changed: 25 additions & 0 deletions
@@ -7,6 +7,7 @@

 import pytest
 import torch
+from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
 from executorch.backends.arm.quantizer import VgfQuantizer
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_quantization_config,
@@ -18,6 +19,7 @@
 from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
 from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
 from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -88,3 +90,26 @@ def test_roundtrip_extracts_io_params_tosa_INT(
     assert isinstance(out_name, str)
     assert isinstance(out_params["scale"], float)
     assert isinstance(out_params["zero_point"], int)
+
+
+def test_only_vgf_partitioner_registers_grid_sampler_no_target_custom_partition_op():
+    tosa_partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
+    vgf_partitioner = VgfPartitioner(VgfCompileSpec("TOSA-1.0+FP"))
+    ethosu_partitioner = EthosUPartitioner(EthosUCompileSpec("ethos-u55-128"))
+
+    assert hasattr(tosa_partitioner, "_custom_partition_ops")
+    assert hasattr(vgf_partitioner, "_custom_partition_ops")
+    assert hasattr(ethosu_partitioner, "_custom_partition_ops")
+
+    assert (
+        exir_ops.edge.aten.grid_sampler_2d.default
+        not in tosa_partitioner._custom_partition_ops
+    )
+    assert (
+        exir_ops.edge.aten.grid_sampler_2d.default
+        in vgf_partitioner._custom_partition_ops
+    )
+    assert (
+        exir_ops.edge.aten.grid_sampler_2d.default
+        not in ethosu_partitioner._custom_partition_ops
+    )
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import VgfPipeline
+
+input_t = Tuple[torch.Tensor, torch.Tensor]
+aten_op = "torch.ops.aten.grid_sampler.default"
+exir_op = "executorch_exir_dialects_edge__ops_aten_grid_sampler_2d_default"
+
+test_data_suite = {
+    "2d_bilinear_zeros": lambda: (
+        torch.randn(1, 3, 8, 8),
+        torch.randn(1, 4, 4, 2),
+    ),
+}
+
+xfails = {
+    "2d_bilinear_zeros": (
+        "CI model_converter does not yet include Vulkan custom-shader "
+        "tosa.custom legalization",
+        RuntimeError,
+    ),
+}
+
+
+class GridSampler2d(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.interpolation_mode_ = 0
+        self.padding_mode_ = 0
+        self.align_corners_ = False
+
+    def forward(self, x, grid):
+        return F.grid_sample(
+            x,
+            grid,
+            mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
+            padding_mode="zeros" if self.padding_mode_ == 0 else "border",
+            align_corners=self.align_corners_,
+        )
+
+
+@common.parametrize("test_data", test_data_suite, xfails=xfails, strict=False)
+@common.SkipIfNoModelConverter
+def test_grid_sampler_vgf_no_quant(test_data):
+    test_data = test_data()
+    pipeline = VgfPipeline[input_t](
+        GridSampler2d(),
+        test_data,
+        aten_op,
+        exir_op,
+        quantize=False,
+        run_on_vulkan_runtime=False,
+    )
+    pipeline.run()
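
Review note: the module under test runs `F.grid_sample` with `align_corners=False`, and the `torch.randn` grid will usually contain coordinates outside `[-1, 1]`, so the zero-padding path is likely exercised as well. As a reminder of what `align_corners` changes, here is a small sketch of the normalized-to-pixel coordinate mapping described in the PyTorch `grid_sample` documentation. `unnormalize` is a hypothetical helper name for illustration and is not taken from the shader or the commit.

```python
def unnormalize(coord: float, size: int, align_corners: bool) -> float:
    """Map a normalized coordinate in [-1, 1] to a pixel index along one axis."""
    if align_corners:
        # -1 and +1 refer to the centers of the first and last pixels.
        return (coord + 1.0) / 2.0 * (size - 1)
    # -1 and +1 refer to the outer edges of the first and last pixels.
    return ((coord + 1.0) * size - 1.0) / 2.0


# For the 8-pixel-wide input used in the test above:
left_edge = unnormalize(-1.0, 8, align_corners=False)
right_edge = unnormalize(1.0, 8, align_corners=False)
```

With `align_corners=False` the extreme coordinates land half a pixel outside the image, so even in-range grid values near the border blend with padding under bilinear interpolation.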
