diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index d6fd4b18b53..f54ed851240 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Set, Type +from collections.abc import Mapping +from typing import Sequence, Set, Type import torch._export.utils import torch.fx @@ -18,6 +19,7 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) +from executorch.backends.arm.tosa.dialect.shape import meta_has_shape_mark from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, @@ -53,6 +55,36 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.exported_program = exported_program + @staticmethod + def _is_tosa_dialect_op(target) -> bool: + target_str = str(target) + return ( + "executorch.exir.dialects.backend._ops.tosa." in target_str + or " bool: + if isinstance(arg, torch.fx.Node): + if meta_has_shape_mark(arg.meta): + return True + return FuseConstantArgsPass._arg_contains_symbolic_shape( + arg.meta.get("val") + ) + if isinstance(arg, torch.SymInt): + return True + if isinstance(arg, Mapping): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(k) + or FuseConstantArgsPass._arg_contains_symbolic_shape(v) + for k, v in arg.items() + ) + if isinstance(arg, Sequence) and not isinstance(arg, (str, bytes)): + return any( + FuseConstantArgsPass._arg_contains_symbolic_shape(v) for v in arg + ) + return False + def _propagate_special_dtype(self, from_nodes, to_node, data): """Propagate special dtype meta if it exists.""" special_dtypes = set() @@ -142,13 +174,13 @@ def call(self, graph_module): for node in graph_module.graph.nodes: if node.op != "call_function": continue - if node.target in [ - exir_ops.backend.tosa.MATMUL.default, - exir_ops.backend.tosa.RESCALE.default, - exir_ops.backend.tosa.RESIZE.default, - exir_ops.backend.tosa.TABLE.default, - exir_ops.backend.tosa.TRANSPOSE.default, - ]: + # Don't fuse TOSA dialect ops as they do not have eager forward functions. + # Also don't fuse ops whose explicit args/kwargs include symbolic shape values. + if ( + self._is_tosa_dialect_op(node.target) + or self._arg_contains_symbolic_shape(node.args) + or self._arg_contains_symbolic_shape(node.kwargs) + ): continue input_nodes = node.all_input_nodes @@ -164,7 +196,6 @@ def call(self, graph_module): ) if not all(input_nodes_constant): continue - try: did_fuse = self._fuse_nodes(node) if did_fuse: diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py index 785744c1b37..d915b4ecba0 100644 --- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -6,6 +6,7 @@ import operator from typing import cast, ClassVar, Dict, Protocol, Tuple +import executorch.backends.arm.tosa.dialect # noqa: F401 import torch from executorch.backends.arm._passes.fuse_constant_ops_pass import ( ComputeConstantOpsAOTPass, @@ -15,8 +16,15 @@ from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import PassPipeline -from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) from executorch.backends.test.harness.stages import StageType +from executorch.backends.test.program_builder import ProgramBuilder +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind input_t = Tuple[torch.Tensor] # Input x input_t2 = Tuple[torch.Tensor, torch.Tensor] @@ -270,3 +278,70 @@ def test_fuse_constant_args_tosa_INT_cat_uses_top_level_arg_qparams() -> None: for node in pass_result.graph_module.graph.nodes if node.op == "placeholder" ] == ["aten_cat_default_fused_const"] + + +def test_fuse_constant_args_identifies_tosa_dialect_targets() -> None: + class FakeTosaTarget: + def __str__(self) -> str: + return "executorch.exir.dialects.backend._ops.tosa.MAX_POOL2D.default" + + assert FuseConstantArgsPass._is_tosa_dialect_op(FakeTosaTarget()) + assert FuseConstantArgsPass._is_tosa_dialect_op( + exir_ops.backend.tosa.GATHER.default + ) + assert not FuseConstantArgsPass._is_tosa_dialect_op(torch.ops.aten.add.Tensor) + + +def test_fuse_constant_args_identifies_symbolic_shape_args() -> None: + graph = torch.fx.Graph() + shape_node = graph.placeholder("shape") + shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + + assert FuseConstantArgsPass._arg_contains_symbolic_shape((shape_node, [1, 2])) + assert not FuseConstantArgsPass._arg_contains_symbolic_shape( + ([1, 2], {"pad": (0, 0)}) + ) + + +def test_fuse_constant_args_skips_backend_tosa_gather(caplog) -> None: + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.1+FP+shape")): + builder = ProgramBuilder() + values = builder.placeholder( + "values", + torch.randn(1, 4, 3), + input_kind=InputKind.CONSTANT_TENSOR, + ) + indices = builder.placeholder( + "indices", + torch.tensor([[0, 2]], dtype=torch.int32), + input_kind=InputKind.CONSTANT_TENSOR, + ) + gather = builder.call_operator( + exir_ops.backend.tosa.GATHER.default, + (values, indices), + ) + builder.output([gather]) + + exported_program = builder.get_program() + graph_module = exported_program.graph_module + + with caplog.at_level("WARNING"): + FuseConstantArgsPass(exported_program)(graph_module) + + warning_messages = [ + record.getMessage() + for record in caplog.records + if record.name == "executorch.backends.arm._passes.fuse_constant_ops_pass" + ] + assert not any( + "Failed to fuse constant op" in message and "GATHER" in message + for message in warning_messages + ) + assert ( + sum( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.GATHER.default + for node in graph_module.graph.nodes + ) + == 1 + )