Skip to content

Commit 27fd39e

Browse files
Arm backend: Disable fusing of TOSA ops
Disable fusing of ops that have symbolic shapes as arguments. Also disable fusing of TOSA dialect ops. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com> Change-Id: Ib53f17dcace6c17fa0978e234d8490d16aee5ecc
1 parent eef7921 commit 27fd39e

2 files changed

Lines changed: 116 additions & 10 deletions

File tree

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from typing import Set, Type
7+
from collections.abc import Mapping
8+
from typing import Sequence, Set, Type
89

910
import torch._export.utils
1011
import torch.fx
@@ -18,6 +19,7 @@
1819
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1920
FuseEqualPlaceholdersPass,
2021
)
22+
from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark
2123
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2224
from executorch.backends.transforms.utils import (
2325
create_constant_placeholder,
@@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
5355
super().__init__(*args, **kwargs)
5456
self.exported_program = exported_program
5557

58+
@staticmethod
59+
def _is_tosa_dialect_op(target) -> bool:
60+
target_str = str(target)
61+
return (
62+
"executorch.exir.dialects.backend._ops.tosa." in target_str
63+
or "<EdgeOpOverload: tosa." in target_str
64+
)
65+
66+
@staticmethod
67+
def _arg_contains_symbolic_shape(arg) -> bool:
68+
if isinstance(arg, torch.fx.Node):
69+
if meta_has_shape_mark(arg.meta):
70+
return True
71+
return FuseConstantArgsPass._arg_contains_symbolic_shape(
72+
arg.meta.get("val")
73+
)
74+
if isinstance(arg, torch.SymInt):
75+
return True
76+
if isinstance(arg, Mapping):
77+
return any(
78+
FuseConstantArgsPass._arg_contains_symbolic_shape(k)
79+
or FuseConstantArgsPass._arg_contains_symbolic_shape(v)
80+
for k, v in arg.items()
81+
)
82+
if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)):
83+
return any(
84+
FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg
85+
)
86+
return False
87+
5688
def _propagate_special_dtype(self, from_nodes, to_node, data):
5789
"""Propagate special dtype meta if it exists."""
5890
special_dtypes = set()
@@ -142,13 +174,13 @@ def call(self, graph_module):
142174
for node in graph_module.graph.nodes:
143175
if node.op != "call_function":
144176
continue
145-
if node.target in [
146-
exir_ops.backend.tosa.MATMUL.default,
147-
exir_ops.backend.tosa.RESCALE.default,
148-
exir_ops.backend.tosa.RESIZE.default,
149-
exir_ops.backend.tosa.TABLE.default,
150-
exir_ops.backend.tosa.TRANSPOSE.default,
151-
]:
177+
# Don't fuse TOSA dialect ops as they do not have eager forward functions.
178+
# Also don't fuse ops whose explicit args/kwargs include symbolic shape values.
179+
if (
180+
self._is_tosa_dialect_op(node.target)
181+
or self._arg_contains_symbolic_shape(node.args)
182+
or self._arg_contains_symbolic_shape(node.kwargs)
183+
):
152184
continue
153185

154186
input_nodes = node.all_input_nodes
@@ -164,7 +196,6 @@ def call(self, graph_module):
164196
)
165197
if not all(input_nodes_constant):
166198
continue
167-
168199
try:
169200
did_fuse = self._fuse_nodes(node)
170201
if did_fuse:

backends/arm/test/passes/test_fuse_constant_ops_pass.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import operator
77
from typing import cast, ClassVar, Dict, Protocol, Tuple
88

9+
import executorch.backends.arm.tosa.dialect # noqa: F401
910
import torch
1011
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
1112
ComputeConstantOpsAOTPass,
@@ -15,8 +16,15 @@
1516
from executorch.backends.arm.test import common
1617
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1718
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
18-
from executorch.backends.arm.tosa import TosaSpecification
19+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
20+
from executorch.backends.arm.tosa.specification import (
21+
TosaLoweringContext,
22+
TosaSpecification,
23+
)
1924
from executorch.backends.test.harness.stages import StageType
25+
from executorch.backends.test.program_builder import ProgramBuilder
26+
from executorch.exir.dialects._ops import ops as exir_ops
27+
from torch.export.graph_signature import InputKind
2028

2129
input_t = Tuple[torch.Tensor] # Input x
2230
input_t2 = Tuple[torch.Tensor, torch.Tensor]
@@ -270,3 +278,70 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None:
270278
for node in pass_result.graph_module.graph.nodes
271279
if node.op == "placeholder"
272280
] == ["aten_cat_default_fused_const"]
281+
282+
283+
def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None:
284+
class FakeTosaTarget:
285+
def __str__(self) -> str:
286+
return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default"
287+
288+
assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget())
289+
assert FuseConstantArgsPass._is_tosa_dialect_op(
290+
exir_ops.backend.tosa.GATHER.default
291+
)
292+
assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor)
293+
294+
295+
def test_fuse_constant_args_identifies_symbolic_shape_args() -> None:
296+
graph = torch.fx.Graph()
297+
shape_node = graph.placeholder("shape")
298+
shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
299+
300+
assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2]))
301+
assert not FuseConstantArgsPass._arg_contains_symbolic_shape(
302+
([1, 2], {"pad": (0, 0)})
303+
)
304+
305+
306+
def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None:
307+
with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")):
308+
builder = ProgramBuilder()
309+
values = builder.placeholder(
310+
"values",
311+
torch.randn(1, 4, 3),
312+
input_kind=InputKind.CONSTANT_TENSOR,
313+
)
314+
indices = builder.placeholder(
315+
"indices",
316+
torch.tensor([[0, 2]], dtype=torch.int32),
317+
input_kind=InputKind.CONSTANT_TENSOR,
318+
)
319+
gather = builder.call_operator(
320+
exir_ops.backend.tosa.GATHER.default,
321+
(values, indices),
322+
)
323+
builder.output([gather])
324+
325+
exported_program = builder.get_program()
326+
graph_module = exported_program.graph_module
327+
328+
with caplog.at_level("WARNING"):
329+
FuseConstantArgsPass(exported_program)(graph_module)
330+
331+
warning_messages = [
332+
record.getMessage()
333+
for record in caplog.records
334+
if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass"
335+
]
336+
assert not any(
337+
"Failed to fuse constant op" in message and "GATHER" in message
338+
for message in warning_messages
339+
)
340+
assert (
341+
sum(
342+
node.op == "call_function"
343+
and node.target == exir_ops.backend.tosa.GATHER.default
344+
for node in graph_module.graph.nodes
345+
)
346+
== 1
347+
)

0 commit comments

Comments
 (0)