Skip to content

Commit f9506f8

Browse files
AdrianLundellBaris Demir
andauthored
Arm backend: Preserve duplicate output slots with TOSA identity fanout (#18866)
When FuseEqualPlaceholdersPass fuses equal constant placeholders, the graph output can contain the same node in multiple output slots. In this case ToTosaMemoryFormatPass was rewriting the output node with replace_input_with() while inserting output transposes. That rewrote all matching occurrences at once, so duplicated logical output slots were collapsed onto the same transpose node instead of remaining distinct. Fix this by handling duplicate outputs in the output rewrite path. For shared output nodes, create a single boundary TOSA TRANSPOSE and preserve distinct output slots by inserting TOSA IDENTITY fanout nodes for later duplicates. This keeps insert_input_transpose() focused on normal input rewrites, avoids duplicating equivalent transposes for shared outputs, and preserves the output slot structure expected by later lowering and serialization stages. Add regression coverage for FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass with duplicate outputs, and add TOSA IDENTITY dialect and visitor coverage. --------- Signed-off-by: Baris Demir <baris.demir@arm.com> Signed-off-by: Adrian Lundell <adrian.lundell@arm.com> Co-authored-by: Baris Demir <baris.demir@arm.com>
1 parent 63217d2 commit f9506f8

11 files changed

Lines changed: 320 additions & 5 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from .decompose_var_pass import DecomposeVarPass # noqa
101101
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
102102
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
103+
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
103104
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
104105
FoldAndAnnotateQParamsPass,
105106
QuantizeClampArgumentsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
DecomposeVarPass,
9999
DecomposeWhereScalarOtherPass,
100100
DecorateFp32toInt32CastingPass,
101+
EnsureUniqueOutputNodesPass,
101102
FoldAndAnnotateQParamsPass,
102103
FuseBatchNorm2dPass,
103104
FuseConsecutiveConcatShapesPass,
@@ -544,6 +545,7 @@ def _tosa_pipeline(
544545
FuseEqualPlaceholdersPass(exported_program),
545546
FuseConsecutiveConcatShapesPass(),
546547
ToTosaMemoryFormatPass(exported_program),
548+
EnsureUniqueOutputNodesPass(),
547549
RemoveNoopPass(),
548550
InsertRescalePass(),
549551
InsertDataLayoutCastsPass(),
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from collections import Counter
7+
from typing import Any, Set, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
16+
class EnsureUniqueOutputNodesPass(ArmPass):
17+
"""Ensure each graph output leaf references a unique producer node.
18+
19+
If the same node appears multiple times in the output structure, insert a
20+
``tosa.IDENTITY`` node for each occurrence and replace the repeated output
21+
entries with those identity nodes.
22+
23+
"""
24+
25+
_passes_required_after: Set[Type[ExportPass]] = set()
26+
27+
@staticmethod
28+
def _collect_output_nodes(
29+
output_value: Any, counts: Counter[torch.fx.Node]
30+
) -> None:
31+
if isinstance(output_value, torch.fx.Node):
32+
counts[output_value] += 1
33+
return
34+
if isinstance(output_value, (list, tuple)):
35+
for value in output_value:
36+
EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts)
37+
38+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
39+
graph = graph_module.graph
40+
output_node = graph.output_node()
41+
output_value = output_node.args[0]
42+
43+
counts: Counter[torch.fx.Node] = Counter()
44+
self._collect_output_nodes(output_value, counts)
45+
repeated_nodes = {node for node, count in counts.items() if count > 1}
46+
if not repeated_nodes:
47+
return PassResult(graph_module, False)
48+
49+
modified = False
50+
51+
def _replace_repeated_outputs(value: Any) -> Any:
52+
nonlocal modified
53+
if isinstance(value, torch.fx.Node):
54+
if value not in repeated_nodes:
55+
return value
56+
with graph.inserting_before(output_node):
57+
identity_node = create_node(
58+
graph,
59+
exir_ops.backend.tosa.IDENTITY.default,
60+
args=(value,),
61+
from_node=value,
62+
)
63+
modified = True
64+
return identity_node
65+
66+
if isinstance(value, tuple):
67+
return tuple(_replace_repeated_outputs(v) for v in value)
68+
69+
if isinstance(value, list):
70+
return [_replace_repeated_outputs(v) for v in value]
71+
72+
return value
73+
74+
new_output_value = _replace_repeated_outputs(output_value)
75+
if modified:
76+
output_node.args = (new_output_value,)
77+
graph.eliminate_dead_code()
78+
graph.lint()
79+
graph_module.recompile()
80+
graph_module = super().call(graph_module).graph_module
81+
82+
return PassResult(graph_module, modified)

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module):
264264
"""Convert a producer's output to channels-last by appending a backend
265265
`TRANSPOSE` node and rewiring its users.
266266
"""
267-
268267
rank = len(get_first_fake_tensor(node).size())
269268
spatial_rank = node.meta["tosa_spatial_rank"]
270269
mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
@@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
383382
if output_dim_orders is None:
384383
raise RuntimeError(f"{output_dim_orders=} is not supported.")
385384

385+
transposed_output_inputs: set[torch.fx.Node] = set()
386386
for output_node_input, output_dim_order in zip(
387387
outputs, output_dim_orders, strict=True
388388
):
389-
if output_dim_order in (
390-
NCHW_ORDER,
391-
NNCHW_ORDER,
392-
NNNCHW_ORDER,
389+
if (
390+
output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER)
391+
and output_node_input not in transposed_output_inputs
393392
):
394393
self.insert_input_transpose(
395394
output_node, output_node_input, graph_module
396395
)
396+
transposed_output_inputs.add(output_node_input)
397397

398398
def remove_dim_order_kwargs(
399399
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
op_tosa_custom,
5353
op_tosa_depthwise_conv2d,
5454
op_tosa_gather,
55+
op_tosa_identity,
5556
op_tosa_matmul,
5657
op_tosa_max_pool2d,
5758
op_tosa_pad,
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
19+
)
20+
from executorch.backends.arm.tosa.mapping import TosaArg
21+
22+
23+
@register_node_visitor
24+
class IdentityVisitor(NodeVisitor):
25+
"""Lower the TOSA IDENTITY op."""
26+
27+
target = "tosa.IDENTITY.default"
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
tosa_graph: Any,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
validate_num_inputs(self.target, inputs, 1)
37+
validate_same_dtype(self.target, [inputs[0], output], ts)
38+
validate_valid_dtype(
39+
self.target,
40+
[inputs[0], output],
41+
[
42+
ts.DType.BOOL,
43+
ts.DType.INT8,
44+
ts.DType.INT16,
45+
ts.DType.INT32,
46+
ts.DType.FP16,
47+
ts.DType.FP32,
48+
ts.DType.BF16,
49+
],
50+
self.tosa_spec,
51+
)
52+
53+
attr = ts.TosaSerializerAttribute()
54+
attr.IdentityAttribute()
55+
self._serialize_operator(
56+
node,
57+
tosa_graph,
58+
ts.Op.IDENTITY,
59+
[inputs[0].name],
60+
[output.name],
61+
attr,
62+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import torch
8+
from executorch.backends.arm.tosa.specification import (
9+
TosaLoweringContext,
10+
TosaSpecification,
11+
)
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch._subclasses.fake_tensor import FakeTensorMode
14+
15+
16+
def test_identity_tosa_FP() -> None:
17+
sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)
18+
19+
with TosaLoweringContext(
20+
TosaSpecification.create_from_string("TOSA-1.0+FP")
21+
), FakeTensorMode() as mode:
22+
output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input))
23+
24+
assert output.dtype == sample_input.dtype
25+
assert tuple(output.shape) == tuple(sample_input.shape)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass
8+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
9+
from executorch.backends.test.harness.stages import StageType
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
13+
class DuplicateOutputModule(torch.nn.Module):
14+
def forward(self, x: torch.Tensor):
15+
y = x + 1.0
16+
return y, y
17+
18+
19+
class UniqueOutputModule(torch.nn.Module):
20+
def forward(self, x: torch.Tensor):
21+
y = x + 1.0
22+
z = x + 2.0
23+
return y, z
24+
25+
26+
def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> (
27+
None
28+
):
29+
pipeline = PassPipeline[tuple[torch.Tensor]](
30+
DuplicateOutputModule(),
31+
(torch.rand(2, 2),),
32+
quantize=False,
33+
pass_list=[EnsureUniqueOutputNodesPass],
34+
ops_after_pass={
35+
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
36+
},
37+
)
38+
pipeline.pop_stage("run_method_and_compare_outputs")
39+
pipeline.run()
40+
41+
graph_module = (
42+
pipeline.tester.get_artifact(StageType.RUN_PASSES)
43+
.exported_program()
44+
.graph_module
45+
)
46+
output_node = graph_module.graph.output_node()
47+
outputs = list(output_node.args[0])
48+
49+
assert outputs[0] is not outputs[1]
50+
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
51+
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
52+
assert outputs[0].args[0] is outputs[1].args[0]
53+
54+
55+
def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None:
56+
pipeline = PassPipeline[tuple[torch.Tensor]](
57+
UniqueOutputModule(),
58+
(torch.rand(2, 2),),
59+
quantize=False,
60+
pass_list=[EnsureUniqueOutputNodesPass],
61+
ops_not_after_pass=[
62+
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
63+
],
64+
)
65+
pipeline.pop_stage("run_method_and_compare_outputs")
66+
pipeline.run()

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
import torch
99
from executorch.backends.arm._passes import (
1010
AnnotateOutputDimOrderPass,
11+
EnsureUniqueOutputNodesPass,
12+
FuseEqualPlaceholdersPass,
1113
ToTosaMemoryFormatPass,
1214
)
1315

1416
from executorch.backends.arm.test import common
1517
from executorch.backends.arm.test.tester.test_pipeline import (
1618
PassPipeline,
19+
TosaPipelineFP,
1720
TosaPipelineINT,
1821
)
1922
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
23+
from executorch.exir.dialects._ops import ops as exir_ops
2024

2125
input_t = Tuple[torch.Tensor] # Input x
2226

@@ -177,6 +181,26 @@ def get_inputs(self) -> input_t:
177181
return (torch.rand(4, 4, 4, 4),)
178182

179183

184+
class DuplicateConstantOutputs(torch.nn.Module):
185+
def __init__(self) -> None:
186+
super().__init__()
187+
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
188+
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))
189+
190+
def forward(self, x: torch.Tensor):
191+
return self.grid0, self.grid1, x
192+
193+
194+
class DuplicateConstantOutputsWithAdd(torch.nn.Module):
195+
def __init__(self) -> None:
196+
super().__init__()
197+
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
198+
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))
199+
200+
def forward(self, x: torch.Tensor):
201+
return self.grid0, self.grid1, x + x
202+
203+
180204
modules: Dict[str, ModuleMetadata] = {
181205
"no_nhwc": NoNHWC(),
182206
"parallel_clusters": ParallelClusters(),
@@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209233
module_nn = cast(torch.nn.Module, module)
210234
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211235
pipeline.run()
236+
237+
238+
def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None:
239+
pipeline = PassPipeline[input_t](
240+
DuplicateConstantOutputs(),
241+
(torch.rand(1, 2, 32, 32),),
242+
quantize=False,
243+
pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
244+
passes_with_exported_program=[
245+
FuseEqualPlaceholdersPass,
246+
ToTosaMemoryFormatPass,
247+
EnsureUniqueOutputNodesPass,
248+
],
249+
)
250+
pipeline.pop_stage("run_method_and_compare_outputs")
251+
pipeline.run()
252+
253+
graph_module = pipeline.tester.get_artifact().exported_program().graph_module
254+
output_node = graph_module.graph.output_node()
255+
outputs = list(output_node.args[0])
256+
257+
assert outputs[0] is not outputs[1]
258+
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
259+
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
260+
assert outputs[0].args[0] is outputs[1].args[0]
261+
262+
263+
def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None:
264+
pipeline = TosaPipelineFP[input_t](
265+
DuplicateConstantOutputsWithAdd(),
266+
(torch.rand(1, 2, 32, 32),),
267+
[],
268+
[],
269+
)
270+
pipeline.run()

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
custom,
1111
depthwise_conv2d,
1212
gather,
13+
identity,
1314
matmul,
1415
max_pool2d,
1516
pad,

0 commit comments

Comments
 (0)