Skip to content

Commit 6f85b45

Browse files
Arm backend: Disable fusing of TOSA ops
Disable fusing of ops that have symbolic shapes as arguments. Also disable fusing of TOSA dialect ops. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: Ib53f17dcace6c17fa0978e234d8490d16aee5ecc
1 parent 9ce1b55 commit 6f85b45

2 files changed

Lines changed: 116 additions & 9 deletions

File tree

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Set, Type
7+
from collections.abc import Mapping
8+
from typing import Sequence, Set, Type
89

910
import torch._export.utils
1011
import torch.fx
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1920
FuseEqualPlaceholdersPass,
2021
)
22+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
2123
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2224
from executorch.backends.transforms.utils import (
2325
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5355
super().__init__(*args, **kwargs)
5456
self.exported_program = exported_program
5557

58+
@staticmethod
59+
def _is_tosa_dialect_op(target) -> bool:
60+
target_str = str(target)
61+
return (
62+
"executorch.exir.dialects.backend._ops.tosa." in target_str
63+
or "<EdgeOpOverload: tosa." in target_str
64+
)
65+
66+
@staticmethod
67+
def _arg_contains_symbolic_shape(arg) -> bool:
68+
if isinstance(arg, torch.fx.Node):
69+
if meta_has_shape_mark(arg.meta):
70+
return True
71+
return FuseConstantArgsPass._arg_contains_symbolic_shape(
72+
arg.meta.get("val")
73+
)
74+
if isinstance(arg, torch.SymInt):
75+
return True
76+
if isinstance(arg, Mapping):
77+
return any(
78+
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
79+
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
80+
for k, v in arg.items()
81+
)
82+
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
83+
return any(
84+
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
85+
)
86+
return False
87+
5688
def _propagate_special_dtype(self, from_nodes, to_node, data):
5789
"""Propagate special dtype meta if it exists."""
5890
special_dtypes = set()
@@ -139,13 +171,13 @@ def call(self, graph_module):
139171
for node in graph_module.graph.nodes:
140172
if node.op != "call_function":
141173
continue
142-
if node.target in [
143-
exir_ops.backend.tosa.MATMUL.default,
144-
exir_ops.backend.tosa.RESCALE.default,
145-
exir_ops.backend.tosa.RESIZE.default,
146-
exir_ops.backend.tosa.TABLE.default,
147-
exir_ops.backend.tosa.TRANSPOSE.default,
148-
]:
174+
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
175+
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
176+
if (
177+
self._is_tosa_dialect_op(node.target)
178+
or self._arg_contains_symbolic_shape(node.args)
179+
or self._arg_contains_symbolic_shape(node.kwargs)
180+
):
149181
continue
150182

151183
input_nodes = node.all_input_nodes
@@ -161,7 +193,6 @@ def call(self, graph_module):
161193
)
162194
if not all(input_nodes_constant):
163195
continue
164-
165196
try:
166197
did_fuse = self._fuse_nodes(node)
167198
if did_fuse:

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,22 @@
66
import operator
77
from typing import cast, ClassVar, Dict, Protocol, Tuple
88

9+
import executorch.backends.arm.tosa.dialect # noqa: F401
910
import torch
1011
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
1112
ComputeConstantOpsAOTPass,
1213
FuseConstantArgsPass,
1314
)
1415
from executorch.backends.arm.test import common
1516
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
17+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
18+
from executorch.backends.arm.tosa.specification import (
19+
TosaLoweringContext,
20+
TosaSpecification,
21+
)
22+
from executorch.backends.test.program_builder import ProgramBuilder
23+
from executorch.exir.dialects._ops import ops as exir_ops
24+
from torch.export.graph_signature import InputKind
1625

1726
input_t = Tuple[torch.Tensor] # Input x
1827
input_t2 = Tuple[torch.Tensor, torch.Tensor]
@@ -174,3 +183,70 @@ def test_fuse_constant_args_tosa_INT_cat(module: ModuleWithFuseAttrs) -> None:
174183
],
175184
)
176185
pipeline.run()
186+
187+
188+
def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None:
189+
class FakeTosaTarget:
190+
def __str__(self) -> str:
191+
return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default"
192+
193+
assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget())
194+
assert FuseConstantArgsPass._is_tosa_dialect_op(
195+
exir_ops.backend.tosa.GATHER.default
196+
)
197+
assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor)
198+
199+
200+
def test_fuse_constant_args_identifies_symbolic_shape_args() -> None:
201+
graph = torch.fx.Graph()
202+
shape_node = graph.placeholder("shape")
203+
shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
204+
205+
assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2]))
206+
assert not FuseConstantArgsPass._arg_contains_symbolic_shape(
207+
([1, 2], {"pad": (0, 0)})
208+
)
209+
210+
211+
def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None:
212+
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
213+
builder = ProgramBuilder()
214+
values = builder.placeholder(
215+
"values",
216+
torch.randn(1, 4, 3),
217+
input_kind=InputKind.CONSTANT_TENSOR,
218+
)
219+
indices = builder.placeholder(
220+
"indices",
221+
torch.tensor([[0, 2]], dtype=torch.int32),
222+
input_kind=InputKind.CONSTANT_TENSOR,
223+
)
224+
gather = builder.call_operator(
225+
exir_ops.backend.tosa.GATHER.default,
226+
(values, indices),
227+
)
228+
builder.output([gather])
229+
230+
exported_program = builder.get_program()
231+
graph_module = exported_program.graph_module
232+
233+
with caplog.at_level("WARNING"):
234+
FuseConstantArgsPass(exported_program)(graph_module)
235+
236+
warning_messages = [
237+
record.getMessage()
238+
for record in caplog.records
239+
if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass"
240+
]
241+
assert not any(
242+
"Failed to fuse constant op" in message and "GATHER" in message
243+
for message in warning_messages
244+
)
245+
assert (
246+
sum(
247+
node.op == "call_function"
248+
and node.target == exir_ops.backend.tosa.GATHER.default
249+
for node in graph_module.graph.nodes
250+
)
251+
== 1
252+
)

0 commit comments

Comments
 (0)