Skip to content

Commit 26a00cb

Browse files
authored
Merge branch 'main' into main
2 parents a08a035 + d98aa22 commit 26a00cb

39 files changed

Lines changed: 1636 additions & 205 deletions

backends/arm/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,27 @@ runtime.python_library(
8787
name = "vgf",
8888
srcs = [
8989
"vgf/__init__.py",
90+
"vgf/_passes/__init__.py",
91+
"vgf/_passes/rewrite_grid_sampler_to_tosa_custom.py",
9092
"vgf/backend.py",
9193
"vgf/compile_spec.py",
9294
"vgf/model_converter.py",
9395
"vgf/partitioner.py",
96+
"vgf/shaders/__init__.py",
97+
"vgf/shaders/grid_sampler.py",
98+
],
99+
resources = [
100+
"vgf/shaders/grid_sampler.glsl",
101+
"vgf/shaders/grid_sampler.spirv.b64",
94102
],
95103
deps = [
96104
":arm_compile_spec",
105+
"//caffe2:torch",
106+
"//executorch/backends/arm/_passes:passes",
107+
"//executorch/backends/arm/tosa/dialect:lib",
97108
"//executorch/backends/arm/tosa:specification",
98109
"//executorch/backends/arm/tosa:partitioner",
110+
"//executorch/exir:lib",
99111
],
100112
)
101113

backends/arm/_passes/convert_int64_output_ops_to_int32.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
import logging
8-
from typing import Set, Type
8+
from typing import cast, Literal, Set, Type
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
@@ -25,26 +25,54 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
2525
"""Rewrites or removes operations that produce int64 outputs, converting
2626
them to int32 where possible.
2727
28-
Currently, this pass handles casting and argmax operators:
28+
Currently, this pass handles casting, argmax and argmin operators:
2929
1. int32 -> int64:
3030
removes the cast and redirects all uses to the original int32 value.
3131
2. other types -> int64:
3232
rewrites the cast to produce int32 instead of int64.
33-
3. torch.argmax()
34-
insert an int64->int32 cast after the argmax node
33+
3. torch.argmax() / torch.argmin()
34+
insert an int64->int32 cast after the argmax/argmin node
3535
36-
Future extensions may include operators that return int64 outputs by default
37-
(e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
38-
int32 results.
36+
Future extensions may include other operators that return int64 outputs by
37+
default, rewriting them or inserting an int64 -> int32 cast to yield int32
38+
results.
3939
40-
Note: Overflow checks are applied selectively in this pass. For operators without
41-
such checks, it is the user's responsibility to ensure that values fit within
42-
the int32 range.
40+
Args:
41+
on_overflow: Action when an argmax/argmin index cannot safely fit in
42+
int32 (i.e. the reduced dimension has more than INT32_MAX elements).
43+
``"raise"`` (default) raises a ``RuntimeError`` at compile time.
44+
``"warn"`` logs a warning and skips the cast for that node.
45+
``"skip"`` silently skips the cast for that node.
4346
4447
"""
4548

4649
_passes_required_after: Set[Type[ExportPass]] = set()
4750

51+
_INT32_MAX = torch.iinfo(torch.int32).max
52+
53+
def __init__(
    self,
    *args,
    on_overflow: Literal["raise", "warn", "skip"] = "raise",
    **kwargs,
) -> None:
    """Initialise the pass and record the overflow policy.

    Args:
        *args: Forwarded unchanged to the parent initialiser.
        on_overflow: One of ``"raise"``, ``"warn"`` or ``"skip"``; stored on
            the instance as ``self.on_overflow``.
        **kwargs: Forwarded unchanged to the parent initialiser.

    Raises:
        ValueError: If ``on_overflow`` is not one of the accepted values.

    """
    super().__init__(*args, **kwargs)
    # Literal[...] is only a static hint, so reject bad values at runtime too.
    is_known_policy = on_overflow in ("raise", "warn", "skip")
    if not is_known_policy:
        raise ValueError(
            f"on_overflow must be 'raise', 'warn', or 'skip', got {on_overflow!r}"
        )
    self.on_overflow = on_overflow
65+
66+
def _is_int32_range_safe(self, node: torch.fx.Node) -> bool:
    """Return True if the argmax/argmin index output fits in int32.

    The check compares the number of elements being reduced over against
    ``_INT32_MAX``. This is deliberately conservative: the largest index
    produced is ``size - 1``, yet a size of exactly ``_INT32_MAX + 1`` is
    still reported as unsafe.

    Args:
        node: The argmax/argmin node. ``node.args[0]`` is the input node;
            the reduction dimension is read from ``node.args[1]`` or, when
            it was not passed positionally, from ``node.kwargs["dim"]``.

    Returns:
        True when the reduced extent is at most ``_INT32_MAX``.

    """
    input_tensor = get_first_fake_tensor(cast(torch.fx.Node, node.args[0]))
    # Honour a ``dim`` supplied by keyword as well as positionally; without
    # this the check would silently fall back to the whole-tensor size.
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim")
    if dim is None or input_tensor.dim() == 0:
        # No dim means the reduction runs over the flattened tensor. A 0-d
        # input has an empty shape, so indexing it with dim=0/-1 would raise
        # IndexError; its single element is covered by numel() as well.
        size = input_tensor.numel()
    else:
        size = input_tensor.shape[cast(int, dim)]
    return size <= self._INT32_MAX
75+
4876
aten_cast_ops = (
4977
torch.ops.aten.to.dtype,
5078
torch.ops.aten.to.dtype_layout,
@@ -54,8 +82,11 @@ class ConvertInt64OutputOpsToInt32Pass(ArmPass):
5482
aten_argmax_ops = (torch.ops.aten.argmax.default,)
5583
edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)
5684

57-
aten_ops = aten_cast_ops + aten_argmax_ops
58-
edge_ops = edge_cast_ops + edge_argmax_ops
85+
aten_argmin_ops = (torch.ops.aten.argmin.default,)
86+
edge_argmin_ops = (exir_ops.edge.aten.argmin.default,)
87+
88+
aten_ops = aten_cast_ops + aten_argmax_ops + aten_argmin_ops
89+
edge_ops = edge_cast_ops + edge_argmax_ops + edge_argmin_ops
5990

6091
# dtype is specified in args
6192
cast_ops_args = (
@@ -104,7 +135,7 @@ def _convert_casting_operators(self, node: torch.fx.Node):
104135
f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}"
105136
)
106137

107-
def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
138+
def _cast_int64_output_to_int32(self, node: torch.fx.Node, graph: torch.fx.Graph):
108139
output_tensor = node
109140
to_copy_op = self._get_decomposition(node.target)
110141
with graph.inserting_after(node):
@@ -138,9 +169,23 @@ def call(self, graph_module: torch.fx.GraphModule):
138169

139170
if node.target in self.aten_cast_ops + self.edge_cast_ops:
140171
self._convert_casting_operators(node)
141-
elif node.target in self.aten_argmax_ops + self.edge_argmax_ops:
142-
# TODO: Add range check based on the input tensor shape before casting the output
143-
self._convert_argmax_operators(node, graph)
172+
elif node.target in (
173+
self.aten_argmax_ops
174+
+ self.edge_argmax_ops
175+
+ self.aten_argmin_ops
176+
+ self.edge_argmin_ops
177+
):
178+
if not self._is_int32_range_safe(node):
179+
msg = (
180+
f"{node.target} reduces over more than {self._INT32_MAX} elements; "
181+
f"the int64 index cannot be safely cast to int32."
182+
)
183+
if self.on_overflow == "raise":
184+
raise RuntimeError(msg)
185+
if self.on_overflow == "warn":
186+
logger.warning(msg)
187+
continue
188+
self._cast_int64_output_to_int32(node, graph)
144189
else:
145190
raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
146191

backends/arm/_passes/decompose_grouped_conv_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,10 @@ def call_operator(self, op, args, kwargs, meta):
257257

258258
input_node = args[0]
259259
if DecomposeGroupedConvPass._is_depthwise_conv(input_node, groups, transposed):
260-
# This is a depthwise convolution which is handled elsewhere
261-
return super().call_operator(op, args, kwargs, meta)
260+
# Conv2D depthwise maps to TOSA DEPTHWISE_CONV2D — handled in RewriteConvPass.
261+
# Conv3D has no DEPTHWISE_CONV3D, so fall through and decompose like grouped conv.
262+
if len(input_node.data.shape) != 5:
263+
return super().call_operator(op, args, kwargs, meta)
262264

263265
weight_node = args[1]
264266
bias_node = args[2]

backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,13 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
129129

130130
def _is_conv3d(self, rank, groups) -> bool:
131131
if rank == 5:
132-
# A Conv3D is considered depthwise if Group == InChannels and
133-
# Group * N == OutChannels, where N is a possitive integer.
134-
# Currently we do not support depthwise or grouped conv3d.
135-
# @TODO Add grouped/depthwise conv3d support or reject in partitioner.
132+
# Both grouped and depthwise Conv3D are decomposed into groups==1
133+
# convolutions by DecomposeGroupedConvPass before reaching here.
134+
# This guard is defense-in-depth for paths that bypass that pass.
136135
if groups != 1:
137136
raise RuntimeError(
138-
"CONV3D with groups != 1 is not supported in the Arm backend."
137+
"CONV3D with groups != 1 reached unexpectedly; "
138+
"DecomposeGroupedConvPass should have decomposed it first."
139139
)
140140
return True
141141
return False

backends/arm/ethosu/partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from typing import final, Optional, Sequence
77

8-
import torch
98
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
109
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
1110
from executorch.exir.backend.partitioner import DelegationSpec
11+
from torch._ops import OpOverload
1212
from torch.fx.passes.operator_support import OperatorSupportBase
1313

1414

@@ -33,5 +33,5 @@ def __init__(
3333
)
3434
self.additional_checks = additional_checks
3535
self.tosa_spec = compile_spec.tosa_spec
36-
self._custom_partition_ops: set[torch._ops.OpOverload] = set()
36+
self._custom_partition_ops: set[OpOverload] = set()
3737
self.intermediate_path = compile_spec._get_intermediate_path()
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import argparse
7+
import base64
8+
import shutil
9+
import subprocess # nosec B404 - required to invoke the shader compiler.
10+
import tempfile
11+
from pathlib import Path
12+
13+
14+
# Shader directory, located relative to this script: two levels up from the
# file, then down into vgf/shaders.
SHADER_DIR = Path(__file__).resolve().parents[1].joinpath("vgf", "shaders")
# Default GLSL input and base64 SPIR-V output used when no flags are given.
DEFAULT_SOURCE = SHADER_DIR.joinpath("grid_sampler.glsl")
DEFAULT_OUTPUT = SHADER_DIR.joinpath("grid_sampler.spirv.b64")
17+
18+
19+
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the shader build script.

    Returns:
        Namespace with ``source`` and ``output`` (converted to ``Path``) and
        ``glslc`` (a string naming the compiler executable).

    """
    description = (
        "Compile the VGF grid_sampler GLSL shader to SPIR-V and write the "
        "base64-encoded payload consumed by the ExecuTorch custom-shader "
        "lowering."
    )
    parser = argparse.ArgumentParser(description=description)

    # The two path options differ only in flag, default and help text.
    path_options = (
        (
            "--source",
            DEFAULT_SOURCE,
            f"GLSL source file. Defaults to {DEFAULT_SOURCE}",
        ),
        (
            "--output",
            DEFAULT_OUTPUT,
            f"Base64 SPIR-V output file. Defaults to {DEFAULT_OUTPUT}",
        ),
    )
    for flag, default, help_text in path_options:
        parser.add_argument(flag, type=Path, default=default, help=help_text)

    parser.add_argument(
        "--glslc",
        default="glslc",
        help="Path to glslc. Defaults to resolving glslc from PATH.",
    )
    return parser.parse_args()
45+
46+
47+
def _resolve_glslc(glslc: str) -> str:
48+
resolved = shutil.which(glslc)
49+
if resolved is None:
50+
raise RuntimeError(
51+
f"Could not find {glslc}. Install the Vulkan SDK or pass --glslc."
52+
)
53+
return resolved
54+
55+
56+
def _write_base64_spirv(spirv_path: Path, output_path: Path) -> None:
57+
encoded = base64.b64encode(spirv_path.read_bytes()).decode("ascii")
58+
output_path.write_text(encoded + "\n", encoding="utf-8")
59+
60+
61+
def main() -> None:
    """Compile the GLSL shader with glslc and emit its base64 SPIR-V.

    The compiler writes into a temporary directory; the encoded result is
    written to the requested output before that directory is removed.
    ``subprocess.run`` is called with ``check=True``, so a failing compiler
    run raises ``subprocess.CalledProcessError``.

    """
    options = _parse_args()
    compiler = _resolve_glslc(options.glslc)

    with tempfile.TemporaryDirectory() as scratch_dir:
        compiled_path = Path(scratch_dir, "grid_sampler.spirv")
        command = [compiler, str(options.source), "-o", str(compiled_path)]
        # glslc path is resolved explicitly, argument list is not a shell string.
        subprocess.run(command, check=True)  # nosec B603
        # Encode while the temporary directory still exists.
        _write_base64_spirv(compiled_path, options.output)
72+
73+
74+
if __name__ == "__main__":
75+
main()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import base64
7+
8+
import pytest
9+
from executorch.backends.arm.vgf.shaders.grid_sampler import (
10+
build_grid_sampler_2d_payload,
11+
decode_payload,
12+
encode_payload,
13+
GRID_SAMPLER_2D_SHADER_BINARY,
14+
GRID_SAMPLER_2D_SHADER_ENTRY_POINT,
15+
GRID_SAMPLER_2D_SHADER_LANGUAGE,
16+
GRID_SAMPLER_2D_SHADER_SOURCE,
17+
GRID_SAMPLER_2D_VK_FORMAT,
18+
GRID_SAMPLER_2D_WORKGROUP_SIZES,
19+
)
20+
21+
22+
def test_grid_sampler_2d_custom_shader_payload_no_target_round_trip():
    """Encode then decode a payload and check every expected field survives."""
    original = build_grid_sampler_2d_payload(
        interpolation_mode=0,
        padding_mode=2,
        align_corners=True,
    )
    decoded = decode_payload(encode_payload(original))

    assert decoded["entry_point"] == GRID_SAMPLER_2D_SHADER_ENTRY_POINT
    assert decoded["workgroup_sizes"] == GRID_SAMPLER_2D_WORKGROUP_SIZES
    assert decoded["shader_language"] == GRID_SAMPLER_2D_SHADER_LANGUAGE
    # The decoded shader must start with the SPIR-V magic number bytes.
    shader_prefix = base64.b64decode(decoded["shader_code"])[:4]
    assert shader_prefix == b"\x03\x02\x23\x07"

    # Two inputs and one output, bound to slots 0, 1 and 2 in that order.
    for binding, resource in enumerate(("input_0", "input_1", "output_0")):
        assert decoded[f"{resource}_type"] == "Tensor"
        assert decoded[f"{resource}_vkformat"] == GRID_SAMPLER_2D_VK_FORMAT
        assert (
            decoded[f"{resource}_vkdescriptortype"]
            == "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"
        )
        assert decoded[f"{resource}_binding"] == binding
46+
47+
48+
def test_grid_sampler_2d_custom_shader_payload_no_target_uses_spirv():
    """The un-encoded payload declares SPIR-V and carries a SPIR-V binary."""
    modes = {"interpolation_mode": 0, "padding_mode": 0, "align_corners": False}
    payload = build_grid_sampler_2d_payload(**modes)

    # First four bytes of the decoded shader are the SPIR-V magic number.
    spirv_magic = base64.b64decode(payload["shader_code"])[:4]

    assert payload["shader_language"] == "SPIR-V"
    assert spirv_magic == b"\x03\x02\x23\x07"
59+
60+
61+
def test_grid_sampler_2d_custom_shader_payload_no_target_has_shader_resources():
    """The shader source/binary constants name the expected resource files."""
    expected_source = "grid_sampler.glsl"
    expected_binary = "grid_sampler.spirv.b64"

    assert GRID_SAMPLER_2D_SHADER_SOURCE == expected_source
    assert GRID_SAMPLER_2D_SHADER_BINARY == expected_binary
64+
65+
66+
def test_grid_sampler_2d_custom_shader_payload_no_target_rejects_bad_modes():
    """Out-of-range interpolation and padding modes raise ``ValueError``."""
    # Checked in order: bad interpolation mode first, then bad padding mode.
    rejected_cases = (
        (
            "Unsupported interpolation_mode",
            {"interpolation_mode": 99, "padding_mode": 0},
        ),
        (
            "Unsupported padding_mode",
            {"interpolation_mode": 0, "padding_mode": 99},
        ),
    )
    for message_pattern, modes in rejected_cases:
        with pytest.raises(ValueError, match=message_pattern):
            build_grid_sampler_2d_payload(align_corners=False, **modes)

0 commit comments

Comments
 (0)