Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backends/arm/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,27 @@ runtime.python_library(
name = "vgf",
srcs = [
"vgf/__init__.py",
"vgf/_passes/__init__.py",
"vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py",
"vgf/backend.py",
"vgf/compile_spec.py",
"vgf/model_converter.py",
"vgf/partitioner.py",
"vgf/shaders/__init__.py",
"vgf/shaders/grid_sampler.py",
],
resources = [
"vgf/shaders/grid_sampler.glsl",
"vgf/shaders/grid_sampler.spirv.b64",
],
deps = [
":arm_compile_spec",
"//caffe2:torch",
"//executorch/backends/arm/_passes:passes",
"//executorch/backends/arm/tosa/dialect:lib",
"//executorch/backends/arm/tosa:specification",
"//executorch/backends/arm/tosa:partitioner",
"//executorch/exir:lib",
],
)

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/ethosu/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from typing import final, Optional, Sequence

import torch
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.exir.backend.partitioner import DelegationSpec
from torch._ops import OpOverload
from torch.fx.passes.operator_support import OperatorSupportBase


Expand All @@ -33,5 +33,5 @@ def __init__(
)
self.additional_checks = additional_checks
self.tosa_spec = compile_spec.tosa_spec
self._custom_partition_ops: set[torch._ops.OpOverload] = set()
self._custom_partition_ops: set[OpOverload] = set()
self.intermediate_path = compile_spec._get_intermediate_path()
75 changes: 75 additions & 0 deletions backends/arm/scripts/generate_grid_sampler_spirv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import base64
import shutil
import subprocess # nosec B404 - required to invoke the shader compiler.
import tempfile
from pathlib import Path


SHADER_DIR = Path(__file__).resolve().parents[1] / "vgf" / "shaders"
DEFAULT_SOURCE = SHADER_DIR / "grid_sampler.glsl"
DEFAULT_OUTPUT = SHADER_DIR / "grid_sampler.spirv.b64"


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Compile the VGF grid_sampler GLSL shader to SPIR-V and write the "
"base64-encoded payload consumed by the ExecuTorch custom-shader "
"lowering."
)
)
parser.add_argument(
"--source",
type=Path,
default=DEFAULT_SOURCE,
help=f"GLSL source file. Defaults to {DEFAULT_SOURCE}",
)
parser.add_argument(
"--output",
type=Path,
default=DEFAULT_OUTPUT,
help=f"Base64 SPIR-V output file. Defaults to {DEFAULT_OUTPUT}",
)
parser.add_argument(
"--glslc",
default="glslc",
help="Path to glslc. Defaults to resolving glslc from PATH.",
)
return parser.parse_args()


def _resolve_glslc(glslc: str) -> str:
resolved = shutil.which(glslc)
if resolved is None:
raise RuntimeError(
f"Could not find {glslc}. Install the Vulkan SDK or pass --glslc."
)
return resolved


def _write_base64_spirv(spirv_path: Path, output_path: Path) -> None:
encoded = base64.b64encode(spirv_path.read_bytes()).decode("ascii")
output_path.write_text(encoded + "\n", encoding="utf-8")


def main() -> None:
args = _parse_args()
glslc = _resolve_glslc(args.glslc)

with tempfile.TemporaryDirectory() as tmpdir:
spirv_path = Path(tmpdir) / "grid_sampler.spirv"
subprocess.run( # nosec B603 - glslc path is resolved explicitly.
[glslc, str(args.source), "-o", str(spirv_path)],
check=True,
)
_write_base64_spirv(spirv_path, args.output)


if __name__ == "__main__":
main()
79 changes: 79 additions & 0 deletions backends/arm/test/misc/test_custom_shader_payload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import base64

import pytest
from executorch.backends.arm.vgf.shaders.grid_sampler import (
build_grid_sampler_2d_payload,
decode_payload,
encode_payload,
GRID_SAMPLER_2D_SHADER_BINARY,
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
GRID_SAMPLER_2D_SHADER_LANGUAGE,
GRID_SAMPLER_2D_SHADER_SOURCE,
GRID_SAMPLER_2D_VK_FORMAT,
GRID_SAMPLER_2D_WORKGROUP_SIZES,
)


def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip():
payload = build_grid_sampler_2d_payload(
interpolation_mode=0,
padding_mode=2,
align_corners=True,
)
decoded = decode_payload(encode_payload(payload))

assert decoded["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
assert decoded["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
assert decoded["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
assert base64.b64decode(decoded["shader_code"])[:4] == b"\x03\x02\x23\x07"
assert decoded["input_0_type"] == "Tensor"
assert decoded["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert decoded["input_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
assert decoded["input_0_binding"] == 0
assert decoded["input_1_type"] == "Tensor"
assert decoded["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert decoded["input_1_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
assert decoded["input_1_binding"] == 1
assert decoded["output_0_type"] == "Tensor"
assert decoded["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert decoded["output_0_vkdescriptortype"] == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
assert decoded["output_0_binding"] == 2


def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv():
payload = build_grid_sampler_2d_payload(
interpolation_mode=0,
padding_mode=0,
align_corners=False,
)

shader_binary = base64.b64decode(payload["shader_code"])

assert payload["shader_language"] == "SPIR-V"
assert shader_binary[:4] == b"\x03\x02\x23\x07"


def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources():
assert GRID_SAMPLER_2D_SHADER_SOURCE == "grid_sampler.glsl"
assert GRID_SAMPLER_2D_SHADER_BINARY == "grid_sampler.spirv.b64"


def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes():
with pytest.raises(ValueError, match="Unsupported interpolation_mode"):
build_grid_sampler_2d_payload(
interpolation_mode=99,
padding_mode=0,
align_corners=False,
)

with pytest.raises(ValueError, match="Unsupported padding_mode"):
build_grid_sampler_2d_payload(
interpolation_mode=0,
padding_mode=99,
align_corners=False,
)
25 changes: 25 additions & 0 deletions backends/arm/test/misc/test_extract_io_params_tosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
import torch
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import VgfQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_quantization_config,
Expand All @@ -18,6 +19,7 @@
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

Expand Down Expand Up @@ -88,3 +90,26 @@ def test_roundtrip_extracts_io_params_tosa_INT(
assert isinstance(out_name, str)
assert isinstance(out_params["scale"], float)
assert isinstance(out_params["zero_point"], int)


def test_only_vgf_partitioner_registers_grid_sampler_no_target_custom_partition_op():
tosa_partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
vgf_partitioner = VgfPartitioner(VgfCompileSpec("TOSA-1.0+FP"))
ethosu_partitioner = EthosUPartitioner(EthosUCompileSpec("ethos-u55-128"))

assert hasattr(tosa_partitioner, "_custom_partition_ops")
assert hasattr(vgf_partitioner, "_custom_partition_ops")
assert hasattr(ethosu_partitioner, "_custom_partition_ops")

assert (
exir_ops.edge.aten.grid_sampler_2d.default
not in tosa_partitioner._custom_partition_ops
)
assert (
exir_ops.edge.aten.grid_sampler_2d.default
in vgf_partitioner._custom_partition_ops
)
assert (
exir_ops.edge.aten.grid_sampler_2d.default
not in ethosu_partitioner._custom_partition_ops
)
62 changes: 62 additions & 0 deletions backends/arm/test/ops/test_grid_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn.functional as F
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import VgfPipeline

input_t = Tuple[torch.Tensor, torch.Tensor]
aten_op = "torch.ops.aten.grid_sampler.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_grid_sampler_2d_default"

test_data_suite = {
"2d_bilinear_zeros": lambda: (
torch.randn(1, 3, 8, 8),
torch.randn(1, 4, 4, 2),
),
}

xfails = {
"2d_bilinear_zeros": (
"CI model_converter does not yet include Vulkan custom-shader "
"tosa.custom legalization",
RuntimeError,
),
}


class GridSampler2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.interpolation_mode_ = 0
self.padding_mode_ = 0
self.align_corners_ = False

def forward(self, x, grid):
return F.grid_sample(
x,
grid,
mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
padding_mode="zeros" if self.padding_mode_ == 0 else "border",
align_corners=self.align_corners_,
)


@common.parametrize("test_data", test_data_suite, xfails=xfails, strict=False)
@common.SkipIfNoModelConverter
def test_grid_sampler_vgf_no_quant(test_data):
test_data = test_data()
pipeline = VgfPipeline[input_t](
GridSampler2d(),
test_data,
aten_op,
exir_op,
quantize=False,
run_on_vulkan_runtime=False,
)
pipeline.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import torch
import torch.nn.functional as F
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.backends.arm.vgf._passes.rewrite_grid_sampler_to_tosa_custom import (
RewriteGridSamplerToTosaCustomPass,
)
from executorch.backends.arm.vgf.shaders.grid_sampler import (
CUSTOM_SHADER_DOMAIN_NAME,
decode_payload,
GRID_SAMPLER_2D_OPERATOR_NAME,
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
GRID_SAMPLER_2D_SHADER_LANGUAGE,
GRID_SAMPLER_2D_VK_FORMAT,
GRID_SAMPLER_2D_WORKGROUP_SIZES,
)
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export


class GridSampler2d(torch.nn.Module):
def __init__(self):
super().__init__()
self.interpolation_mode_ = 0
self.padding_mode_ = 0
self.align_corners_ = False

def forward(self, x, grid):
return F.grid_sample(
x,
grid,
mode="bilinear" if self.interpolation_mode_ == 0 else "nearest",
padding_mode="zeros" if self.padding_mode_ == 0 else "border",
align_corners=self.align_corners_,
)


def test_rewrite_grid_sampler_to_tosa_custom_no_target():
model = GridSampler2d()
example_inputs = (
torch.randn(1, 3, 8, 8),
torch.randn(1, 4, 4, 2),
)

edge_model = to_edge(export(model, example_inputs))
nodes = list(edge_model.exported_program().graph.nodes)

assert any(
node.target == exir_ops.edge.aten.grid_sampler_2d.default for node in nodes
)

with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")):
edge_model = edge_model.transform([RewriteGridSamplerToTosaCustomPass()])
nodes = list(edge_model.exported_program().graph.nodes)

assert not any(
node.target == exir_ops.edge.aten.grid_sampler_2d.default for node in nodes
)

custom_node = next(
node for node in nodes if node.target == exir_ops.backend.tosa.CUSTOM.default
)
assert custom_node.kwargs["operator_name"] == GRID_SAMPLER_2D_OPERATOR_NAME
assert custom_node.kwargs["domain_name"] == CUSTOM_SHADER_DOMAIN_NAME

payload = decode_payload(custom_node.kwargs["implementation_attrs"])
assert payload["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
assert payload["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
assert payload["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
assert payload["input_0_type"] == "Tensor"
assert payload["input_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert payload["input_0_binding"] == 0
assert payload["input_0_descriptorset"] == 0
assert payload["input_1_type"] == "Tensor"
assert payload["input_1_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert payload["input_1_binding"] == 1
assert payload["input_1_descriptorset"] == 0
assert payload["output_0_type"] == "Tensor"
assert payload["output_0_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
assert payload["output_0_binding"] == 2
assert payload["output_0_descriptorset"] == 0
Loading
Loading