Skip to content

Commit 6ae1b6a

Browse files
Baris Demir and AdrianLundell
authored and committed
Arm backend: Preserve duplicate output slots with TOSA identity fanout
When FuseEqualPlaceholdersPass fuses equal constant placeholders, the graph output can contain the same node in multiple output slots. In this case ToTosaMemoryFormatPass was rewriting the output node with replace_input_with() while inserting output transposes. That rewrote all matching occurrences at once, so duplicated logical output slots were collapsed onto the same transpose node instead of remaining distinct. Fix this by handling duplicate outputs in the output rewrite path. For shared output nodes, create a single boundary TOSA TRANSPOSE and preserve distinct output slots by inserting TOSA IDENTITY fanout nodes for later duplicates. This keeps insert_input_transpose() focused on normal input rewrites, avoids duplicating equivalent transposes for shared outputs, and preserves the output slot structure expected by later lowering and serialization stages. Add regression coverage for FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass with duplicate outputs, and add TOSA IDENTITY dialect and visitor coverage. Signed-off-by: Baris Demir <baris.demir@arm.com> Change-Id: Ie14bc88bfadaad7f993b71ef1b5332b5953b72c8
1 parent 37b12c8 commit 6ae1b6a

11 files changed

Lines changed: 319 additions & 5 deletions

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from .decompose_var_pass import DecomposeVarPass # noqa
9898
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
9999
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
100+
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
100101
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
101102
FoldAndAnnotateQParamsPass,
102103
QuantizeClampArgumentsPass,

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
DecomposeVarPass,
9797
DecomposeWhereScalarOtherPass,
9898
DecorateFp32toInt32CastingPass,
99+
EnsureUniqueOutputNodesPass,
99100
FoldAndAnnotateQParamsPass,
100101
FuseBatchNorm2dPass,
101102
FuseConsecutiveConcatShapesPass,
@@ -502,6 +503,7 @@ def _tosa_pipeline(
502503
FuseEqualPlaceholdersPass(exported_program),
503504
FuseConsecutiveConcatShapesPass(),
504505
ToTosaMemoryFormatPass(exported_program),
506+
EnsureUniqueOutputNodesPass(),
505507
RemoveNoopPass(),
506508
InsertRescalePass(),
507509
]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from collections import Counter
7+
from typing import Any, Set, Type
8+
9+
import torch
10+
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
16+
class EnsureUniqueOutputNodesPass(ArmPass):
17+
"""Ensure each graph output leaf references a unique producer node.
18+
19+
If the same node appears multiple times in the output structure, insert a
20+
``tosa.IDENTITY`` node for each occurrence and replace the repeated output
21+
entries with those identity nodes.
22+
23+
"""
24+
25+
_passes_required_after: Set[Type[ExportPass]] = set()
26+
27+
@staticmethod
28+
def _collect_output_nodes(
29+
output_value: Any, counts: Counter[torch.fx.Node]
30+
) -> None:
31+
if isinstance(output_value, torch.fx.Node):
32+
counts[output_value] += 1
33+
return
34+
if isinstance(output_value, (list, tuple)):
35+
for value in output_value:
36+
EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts)
37+
38+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
39+
graph = graph_module.graph
40+
output_node = graph.output_node()
41+
output_value = output_node.args[0]
42+
43+
counts: Counter[torch.fx.Node] = Counter()
44+
self._collect_output_nodes(output_value, counts)
45+
repeated_nodes = {node for node, count in counts.items() if count > 1}
46+
if not repeated_nodes:
47+
return PassResult(graph_module, False)
48+
49+
modified = False
50+
51+
def _replace_repeated_outputs(value: Any) -> Any:
52+
nonlocal modified
53+
if isinstance(value, torch.fx.Node):
54+
if value not in repeated_nodes:
55+
return value
56+
with graph.inserting_before(output_node):
57+
identity_node = create_node(
58+
graph,
59+
exir_ops.backend.tosa.IDENTITY.default,
60+
args=(value,),
61+
from_node=value,
62+
)
63+
modified = True
64+
return identity_node
65+
66+
if isinstance(value, tuple):
67+
return tuple(_replace_repeated_outputs(v) for v in value)
68+
69+
if isinstance(value, list):
70+
return [_replace_repeated_outputs(v) for v in value]
71+
72+
return value
73+
74+
new_output_value = _replace_repeated_outputs(output_value)
75+
if modified:
76+
output_node.args = (new_output_value,)
77+
graph.eliminate_dead_code()
78+
graph.lint()
79+
graph_module.recompile()
80+
graph_module = super().call(graph_module).graph_module
81+
82+
return PassResult(graph_module, modified)

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module):
264264
"""Convert a producer's output to channels-last by appending a backend
265265
`TRANSPOSE` node and rewiring its users.
266266
"""
267-
268267
rank = len(get_first_fake_tensor(node).size())
269268
spatial_rank = node.meta["tosa_spatial_rank"]
270269
mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
@@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
383382
if output_dim_orders is None:
384383
raise RuntimeError(f"{output_dim_orders=} is not supported.")
385384

385+
transposed_output_inputs: set[torch.fx.Node] = set()
386386
for output_node_input, output_dim_order in zip(
387387
outputs, output_dim_orders, strict=True
388388
):
389-
if output_dim_order in (
390-
NCHW_ORDER,
391-
NNCHW_ORDER,
392-
NNNCHW_ORDER,
389+
if (
390+
output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER)
391+
and output_node_input not in transposed_output_inputs
393392
):
394393
self.insert_input_transpose(
395394
output_node, output_node_input, graph_module
396395
)
396+
transposed_output_inputs.add(output_node_input)
397397

398398
def remove_dim_order_kwargs(
399399
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
op_tosa_conv3d,
5454
op_tosa_depthwise_conv2d,
5555
op_tosa_gather,
56+
op_tosa_identity,
5657
op_tosa_matmul,
5758
op_tosa_pad,
5859
op_tosa_rescale,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
19+
)
20+
from executorch.backends.arm.tosa.mapping import TosaArg
21+
22+
23+
@register_node_visitor
24+
class IdentityVisitor(NodeVisitor):
25+
"""Lower the TOSA IDENTITY op."""
26+
27+
target = "tosa.IDENTITY.default"
28+
29+
def define_node(
30+
self,
31+
node: torch.fx.Node,
32+
tosa_graph: Any,
33+
inputs: List[TosaArg],
34+
output: TosaArg,
35+
) -> None:
36+
validate_num_inputs(self.target, inputs, 1)
37+
validate_same_dtype(self.target, [inputs[0], output], ts)
38+
validate_valid_dtype(
39+
self.target,
40+
[inputs[0], output],
41+
[
42+
ts.DType.BOOL,
43+
ts.DType.INT8,
44+
ts.DType.INT16,
45+
ts.DType.INT32,
46+
ts.DType.FP16,
47+
ts.DType.FP32,
48+
ts.DType.BF16,
49+
],
50+
self.tosa_spec,
51+
)
52+
53+
attr = ts.TosaSerializerAttribute()
54+
self._serialize_operator(
55+
node,
56+
tosa_graph,
57+
ts.Op.IDENTITY,
58+
[inputs[0].name],
59+
[output.name],
60+
attr,
61+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import executorch.backends.arm.tosa.dialect # noqa: F401
7+
import torch
8+
from executorch.backends.arm.tosa.specification import (
9+
TosaLoweringContext,
10+
TosaSpecification,
11+
)
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch._subclasses.fake_tensor import FakeTensorMode
14+
15+
16+
def test_identity_tosa_FP() -> None:
17+
sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)
18+
19+
with TosaLoweringContext(
20+
TosaSpecification.create_from_string("TOSA-1.0+FP")
21+
), FakeTensorMode() as mode:
22+
output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input))
23+
24+
assert output.dtype == sample_input.dtype
25+
assert tuple(output.shape) == tuple(sample_input.shape)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass
8+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
9+
from executorch.backends.test.harness.stages import StageType
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
12+
13+
class DuplicateOutputModule(torch.nn.Module):
14+
def forward(self, x: torch.Tensor):
15+
y = x + 1.0
16+
return y, y
17+
18+
19+
class UniqueOutputModule(torch.nn.Module):
20+
def forward(self, x: torch.Tensor):
21+
y = x + 1.0
22+
z = x + 2.0
23+
return y, z
24+
25+
26+
def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> (
27+
None
28+
):
29+
pipeline = PassPipeline[tuple[torch.Tensor]](
30+
DuplicateOutputModule(),
31+
(torch.rand(2, 2),),
32+
quantize=False,
33+
pass_list=[EnsureUniqueOutputNodesPass],
34+
ops_after_pass={
35+
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
36+
},
37+
)
38+
pipeline.pop_stage("run_method_and_compare_outputs")
39+
pipeline.run()
40+
41+
graph_module = (
42+
pipeline.tester.get_artifact(StageType.RUN_PASSES)
43+
.exported_program()
44+
.graph_module
45+
)
46+
output_node = graph_module.graph.output_node()
47+
outputs = list(output_node.args[0])
48+
49+
assert outputs[0] is not outputs[1]
50+
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
51+
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
52+
assert outputs[0].args[0] is outputs[1].args[0]
53+
54+
55+
def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None:
56+
pipeline = PassPipeline[tuple[torch.Tensor]](
57+
UniqueOutputModule(),
58+
(torch.rand(2, 2),),
59+
quantize=False,
60+
pass_list=[EnsureUniqueOutputNodesPass],
61+
ops_not_after_pass=[
62+
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
63+
],
64+
)
65+
pipeline.pop_stage("run_method_and_compare_outputs")
66+
pipeline.run()

backends/arm/test/passes/test_to_tosa_memory_format.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@
88
import torch
99
from executorch.backends.arm._passes import (
1010
AnnotateOutputDimOrderPass,
11+
EnsureUniqueOutputNodesPass,
12+
FuseEqualPlaceholdersPass,
1113
ToTosaMemoryFormatPass,
1214
)
1315

1416
from executorch.backends.arm.test import common
1517
from executorch.backends.arm.test.tester.test_pipeline import (
1618
PassPipeline,
19+
TosaPipelineFP,
1720
TosaPipelineINT,
1821
)
1922
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
23+
from executorch.exir.dialects._ops import ops as exir_ops
2024

2125
input_t = Tuple[torch.Tensor] # Input x
2226

@@ -177,6 +181,26 @@ def get_inputs(self) -> input_t:
177181
return (torch.rand(4, 4, 4, 4),)
178182

179183

184+
class DuplicateConstantOutputs(torch.nn.Module):
185+
def __init__(self) -> None:
186+
super().__init__()
187+
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
188+
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))
189+
190+
def forward(self, x: torch.Tensor):
191+
return self.grid0, self.grid1, x
192+
193+
194+
class DuplicateConstantOutputsWithAdd(torch.nn.Module):
195+
def __init__(self) -> None:
196+
super().__init__()
197+
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
198+
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))
199+
200+
def forward(self, x: torch.Tensor):
201+
return self.grid0, self.grid1, x + x
202+
203+
180204
modules: Dict[str, ModuleMetadata] = {
181205
"no_nhwc": NoNHWC(),
182206
"parallel_clusters": ParallelClusters(),
@@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
209233
module_nn = cast(torch.nn.Module, module)
210234
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
211235
pipeline.run()
236+
237+
238+
def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None:
239+
pipeline = PassPipeline[input_t](
240+
DuplicateConstantOutputs(),
241+
(torch.rand(1, 2, 32, 32),),
242+
quantize=False,
243+
pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
244+
passes_with_exported_program=[
245+
FuseEqualPlaceholdersPass,
246+
ToTosaMemoryFormatPass,
247+
EnsureUniqueOutputNodesPass,
248+
],
249+
)
250+
pipeline.pop_stage("run_method_and_compare_outputs")
251+
pipeline.run()
252+
253+
graph_module = pipeline.tester.get_artifact().exported_program().graph_module
254+
output_node = graph_module.graph.output_node()
255+
outputs = list(output_node.args[0])
256+
257+
assert outputs[0] is not outputs[1]
258+
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
259+
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
260+
assert outputs[0].args[0] is outputs[1].args[0]
261+
262+
263+
def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None:
264+
pipeline = TosaPipelineFP[input_t](
265+
DuplicateConstantOutputsWithAdd(),
266+
(torch.rand(1, 2, 32, 32),),
267+
[],
268+
[],
269+
)
270+
pipeline.run()

backends/arm/tosa/dialect/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
conv3d,
99
depthwise_conv2d,
1010
gather,
11+
identity,
1112
matmul,
1213
pad,
1314
rescale,

0 commit comments

Comments (0)