-
Notifications
You must be signed in to change notification settings - Fork 989
Arm backend: Preserve duplicate output slots with TOSA identity fanout #18866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6ae1b6a
f675b01
51dfb1b
bf44c30
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from collections import Counter | ||
| from typing import Any, Set, Type | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes import ArmPass | ||
| from executorch.backends.arm._passes.arm_pass_utils import create_node | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from executorch.exir.pass_base import ExportPass, PassResult | ||
|
|
||
|
|
||
| class EnsureUniqueOutputNodesPass(ArmPass): | ||
| """Ensure each graph output leaf references a unique producer node. | ||
|
|
||
| If the same node appears multiple times in the output structure, insert a | ||
| ``tosa.IDENTITY`` node for each occurrence and replace the repeated output | ||
| entries with those identity nodes. | ||
|
|
||
| """ | ||
|
|
||
| _passes_required_after: Set[Type[ExportPass]] = set() | ||
|
|
||
| @staticmethod | ||
| def _collect_output_nodes( | ||
| output_value: Any, counts: Counter[torch.fx.Node] | ||
| ) -> None: | ||
| if isinstance(output_value, torch.fx.Node): | ||
| counts[output_value] += 1 | ||
| return | ||
| if isinstance(output_value, (list, tuple)): | ||
| for value in output_value: | ||
| EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts) | ||
|
|
||
| def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
| graph = graph_module.graph | ||
| output_node = graph.output_node() | ||
| output_value = output_node.args[0] | ||
|
|
||
| counts: Counter[torch.fx.Node] = Counter() | ||
| self._collect_output_nodes(output_value, counts) | ||
| repeated_nodes = {node for node, count in counts.items() if count > 1} | ||
| if not repeated_nodes: | ||
| return PassResult(graph_module, False) | ||
|
|
||
| modified = False | ||
|
|
||
| def _replace_repeated_outputs(value: Any) -> Any: | ||
| nonlocal modified | ||
| if isinstance(value, torch.fx.Node): | ||
| if value not in repeated_nodes: | ||
| return value | ||
| with graph.inserting_before(output_node): | ||
| identity_node = create_node( | ||
| graph, | ||
| exir_ops.backend.tosa.IDENTITY.default, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just for output placeholders I guess to not mess up the signature? |
||
| args=(value,), | ||
| from_node=value, | ||
| ) | ||
| modified = True | ||
| return identity_node | ||
|
|
||
| if isinstance(value, tuple): | ||
| return tuple(_replace_repeated_outputs(v) for v in value) | ||
|
|
||
| if isinstance(value, list): | ||
| return [_replace_repeated_outputs(v) for v in value] | ||
|
|
||
| return value | ||
|
|
||
| new_output_value = _replace_repeated_outputs(output_value) | ||
| if modified: | ||
| output_node.args = (new_output_value,) | ||
| graph.eliminate_dead_code() | ||
| graph.lint() | ||
| graph_module.recompile() | ||
| graph_module = super().call(graph_module).graph_module | ||
|
|
||
| return PassResult(graph_module, modified) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from typing import Any, List | ||
|
|
||
| import torch | ||
| import tosa_serializer as ts | ||
|
|
||
| from executorch.backends.arm.operators.node_visitor import ( | ||
| NodeVisitor, | ||
| register_node_visitor, | ||
| ) | ||
| from executorch.backends.arm.operators.operator_validation_utils import ( | ||
| validate_num_inputs, | ||
| validate_same_dtype, | ||
| validate_valid_dtype, | ||
| ) | ||
| from executorch.backends.arm.tosa.mapping import TosaArg | ||
|
|
||
|
|
||
| @register_node_visitor | ||
| class IdentityVisitor(NodeVisitor): | ||
| """Lower the TOSA IDENTITY op.""" | ||
|
|
||
| target = "tosa.IDENTITY.default" | ||
|
|
||
| def define_node( | ||
| self, | ||
| node: torch.fx.Node, | ||
| tosa_graph: Any, | ||
| inputs: List[TosaArg], | ||
| output: TosaArg, | ||
| ) -> None: | ||
| validate_num_inputs(self.target, inputs, 1) | ||
| validate_same_dtype(self.target, [inputs[0], output], ts) | ||
| validate_valid_dtype( | ||
| self.target, | ||
| [inputs[0], output], | ||
| [ | ||
| ts.DType.BOOL, | ||
| ts.DType.INT8, | ||
| ts.DType.INT16, | ||
| ts.DType.INT32, | ||
| ts.DType.FP16, | ||
| ts.DType.FP32, | ||
| ts.DType.BF16, | ||
| ], | ||
| self.tosa_spec, | ||
| ) | ||
|
Comment on lines
+36
to
+51
|
||
|
|
||
| attr = ts.TosaSerializerAttribute() | ||
| attr.IdentityAttribute() | ||
| self._serialize_operator( | ||
| node, | ||
| tosa_graph, | ||
| ts.Op.IDENTITY, | ||
| [inputs[0].name], | ||
| [output.name], | ||
| attr, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import executorch.backends.arm.tosa.dialect # noqa: F401 | ||
| import torch | ||
| from executorch.backends.arm.tosa.specification import ( | ||
| TosaLoweringContext, | ||
| TosaSpecification, | ||
| ) | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||
|
|
||
|
|
||
| def test_identity_tosa_FP() -> None: | ||
| sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) | ||
|
|
||
| with TosaLoweringContext( | ||
| TosaSpecification.create_from_string("TOSA-1.0+FP") | ||
| ), FakeTensorMode() as mode: | ||
| output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input)) | ||
|
|
||
| assert output.dtype == sample_input.dtype | ||
| assert tuple(output.shape) == tuple(sample_input.shape) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| # Copyright 2026 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
| from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass | ||
| from executorch.backends.arm.test.tester.test_pipeline import PassPipeline | ||
| from executorch.backends.test.harness.stages import StageType | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
|
|
||
|
|
||
| class DuplicateOutputModule(torch.nn.Module): | ||
| def forward(self, x: torch.Tensor): | ||
| y = x + 1.0 | ||
| return y, y | ||
|
|
||
|
|
||
| class UniqueOutputModule(torch.nn.Module): | ||
| def forward(self, x: torch.Tensor): | ||
| y = x + 1.0 | ||
| z = x + 2.0 | ||
| return y, z | ||
|
|
||
|
|
||
| def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> ( | ||
| None | ||
| ): | ||
| pipeline = PassPipeline[tuple[torch.Tensor]]( | ||
| DuplicateOutputModule(), | ||
| (torch.rand(2, 2),), | ||
| quantize=False, | ||
| pass_list=[EnsureUniqueOutputNodesPass], | ||
| ops_after_pass={ | ||
| "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2, | ||
| }, | ||
| ) | ||
| pipeline.pop_stage("run_method_and_compare_outputs") | ||
| pipeline.run() | ||
|
|
||
| graph_module = ( | ||
| pipeline.tester.get_artifact(StageType.RUN_PASSES) | ||
| .exported_program() | ||
| .graph_module | ||
| ) | ||
| output_node = graph_module.graph.output_node() | ||
| outputs = list(output_node.args[0]) | ||
|
|
||
| assert outputs[0] is not outputs[1] | ||
| assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default | ||
| assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default | ||
| assert outputs[0].args[0] is outputs[1].args[0] | ||
|
|
||
|
|
||
| def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None: | ||
| pipeline = PassPipeline[tuple[torch.Tensor]]( | ||
| UniqueOutputModule(), | ||
| (torch.rand(2, 2),), | ||
| quantize=False, | ||
| pass_list=[EnsureUniqueOutputNodesPass], | ||
| ops_not_after_pass=[ | ||
| "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default", | ||
| ], | ||
| ) | ||
| pipeline.pop_stage("run_method_and_compare_outputs") | ||
| pipeline.run() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,15 +8,19 @@ | |
| import torch | ||
| from executorch.backends.arm._passes import ( | ||
| AnnotateOutputDimOrderPass, | ||
| EnsureUniqueOutputNodesPass, | ||
| FuseEqualPlaceholdersPass, | ||
| ToTosaMemoryFormatPass, | ||
| ) | ||
|
|
||
| from executorch.backends.arm.test import common | ||
| from executorch.backends.arm.test.tester.test_pipeline import ( | ||
| PassPipeline, | ||
| TosaPipelineFP, | ||
| TosaPipelineINT, | ||
| ) | ||
| from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
|
|
||
| input_t = Tuple[torch.Tensor] # Input x | ||
|
|
||
|
|
@@ -177,6 +181,26 @@ def get_inputs(self) -> input_t: | |
| return (torch.rand(4, 4, 4, 4),) | ||
|
|
||
|
|
||
| class DuplicateConstantOutputs(torch.nn.Module): | ||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) | ||
| self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) | ||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| return self.grid0, self.grid1, x | ||
|
|
||
|
|
||
| class DuplicateConstantOutputsWithAdd(torch.nn.Module): | ||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) | ||
| self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) | ||
|
|
||
| def forward(self, x: torch.Tensor): | ||
| return self.grid0, self.grid1, x + x | ||
|
|
||
|
|
||
| modules: Dict[str, ModuleMetadata] = { | ||
| "no_nhwc": NoNHWC(), | ||
| "parallel_clusters": ParallelClusters(), | ||
|
|
@@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No | |
| module_nn = cast(torch.nn.Module, module) | ||
| pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None: | ||
| pipeline = PassPipeline[input_t]( | ||
| DuplicateConstantOutputs(), | ||
| (torch.rand(1, 2, 32, 32),), | ||
| quantize=False, | ||
| pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], | ||
| passes_with_exported_program=[ | ||
| FuseEqualPlaceholdersPass, | ||
| ToTosaMemoryFormatPass, | ||
| EnsureUniqueOutputNodesPass, | ||
| ], | ||
|
Comment on lines
+244
to
+248
|
||
| ) | ||
| pipeline.pop_stage("run_method_and_compare_outputs") | ||
| pipeline.run() | ||
|
|
||
| graph_module = pipeline.tester.get_artifact().exported_program().graph_module | ||
| output_node = graph_module.graph.output_node() | ||
| outputs = list(output_node.args[0]) | ||
|
|
||
| assert outputs[0] is not outputs[1] | ||
| assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default | ||
| assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default | ||
| assert outputs[0].args[0] is outputs[1].args[0] | ||
|
|
||
|
|
||
| def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None: | ||
| pipeline = TosaPipelineFP[input_t]( | ||
| DuplicateConstantOutputsWithAdd(), | ||
| (torch.rand(1, 2, 32, 32),), | ||
| [], | ||
| [], | ||
| ) | ||
| pipeline.run() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| custom, | ||
| depthwise_conv2d, | ||
| gather, | ||
| identity, | ||
| matmul, | ||
| max_pool2d, | ||
| pad, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pass inherits
ArmPass.__init__(tfa_pass=False, *args, **kwargs), so accidental construction asEnsureUniqueOutputNodesPass(exported_program)(as done by some test harness paths) will bind the exported program object totfa_passand treat the pass as a transform-for-annotation pass. Define an explicit__init__that accepts an optional/ignoredexported_programpositional parameter and always forwardstfa_pass=FalsetoArmPassto avoid this silent behavior change.