Skip to content

Commit 90cd48f

Browse files
Arm backend: Fix crash in FuseDuplicateUsers (pytorch#20068)
Previously crashed in cases where groups appeared not ordered accordingly to graph.nodes. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 58f3e5d commit 90cd48f

2 files changed

Lines changed: 58 additions & 6 deletions

File tree

backends/arm/_passes/fuse_duplicate_users_pass.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
3434
graph = graph_module.graph
3535
modified = False
3636

37+
node_order = {node: index for index, node in enumerate(graph.nodes)}
3738
producers: Deque[Node] = deque(node for node in graph.nodes)
3839

3940
while producers:
@@ -48,7 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
4849
if len(user_nodes) < 2:
4950
continue
5051

51-
candidate_groups = self._get_candidate_groups(user_nodes)
52+
candidate_groups = self._get_candidate_groups(node_order, user_nodes)
5253

5354
signature_to_user: Dict[Tuple[Hashable, ...], Node] = {}
5455
for group in candidate_groups:
@@ -84,7 +85,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8485

8586
return PassResult(graph_module, modified)
8687

87-
def _get_candidate_groups(self, user_nodes):
88+
def _get_candidate_groups(self, node_order, user_nodes):
8889
users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {}
8990
for user in user_nodes:
9091
if user.graph is None:
@@ -98,9 +99,12 @@ def _get_candidate_groups(self, user_nodes):
9899
target_signature = (user.op, target_key)
99100
users_by_target.setdefault(target_signature, []).append(user)
100101

101-
candidate_groups = [
102-
group for group in users_by_target.values() if len(group) > 1
103-
]
102+
candidate_groups = []
103+
for group in users_by_target.values():
104+
if len(group) > 1:
105+
candidate_groups.append(
106+
sorted(group, key=lambda node: node_order[node])
107+
)
104108

105109
return candidate_groups
106110

backends/arm/test/passes/test_fuse_duplicate_users_pass.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -9,6 +9,7 @@
99
from executorch.backends.arm._passes import FuseDuplicateUsersPass
1010
from executorch.backends.arm.test import common
1111
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
12+
from torch.fx import Graph, GraphModule
1213

1314
input_t = Tuple[torch.Tensor] # Input x
1415

@@ -55,6 +56,42 @@ def forward(self, x):
5556
}
5657

5758

59+
def _set_val(node, val):
60+
node.meta["val"] = val
61+
return node
62+
63+
64+
def _graph_with_users_not_in_node_order() -> GraphModule:
65+
graph = Graph()
66+
x = _set_val(graph.placeholder("x"), torch.ones(1))
67+
y = _set_val(graph.placeholder("y"), torch.ones(1))
68+
69+
later_duplicate = _set_val(
70+
graph.call_function(torch.ops.aten.add.Tensor, (x, y)), torch.ones(1)
71+
)
72+
with graph.inserting_before(later_duplicate):
73+
earlier_duplicate = _set_val(
74+
graph.call_function(torch.ops.aten.add.Tensor, (x, y)), torch.ones(1)
75+
)
76+
consumer = _set_val(
77+
graph.call_function(torch.ops.aten.neg.default, (earlier_duplicate,)),
78+
torch.ones(1),
79+
)
80+
81+
output = graph.output(consumer)
82+
output.meta["val"] = torch.ones(1)
83+
graph.lint()
84+
return GraphModule(torch.nn.Module(), graph)
85+
86+
87+
def _add_node_names(graph_module):
88+
return [
89+
node.name
90+
for node in graph_module.graph.nodes
91+
if node.target == torch.ops.aten.add.Tensor
92+
]
93+
94+
5895
@common.parametrize("module", modules)
5996
def test_fuse_duplicate_users_tosa_FP(module: ModuleWithOps):
6097
pipeline = PassPipeline[input_t](
@@ -68,3 +105,14 @@ def test_fuse_duplicate_users_tosa_FP(module: ModuleWithOps):
68105
],
69106
)
70107
pipeline.run()
108+
109+
110+
def test_fuse_duplicate_users_preserves_graph_order_for_representative():
111+
graph_module = _graph_with_users_not_in_node_order()
112+
assert _add_node_names(graph_module) == ["add_tensor_1", "add_tensor"]
113+
114+
result = FuseDuplicateUsersPass()(graph_module)
115+
116+
result.graph_module.graph.lint()
117+
assert result.modified
118+
assert len(_add_node_names(result.graph_module)) == 1

0 commit comments

Comments
 (0)