Skip to content

Commit 6f6225c

Browse files
Arm backend: Split shared get_attr nodes before dedupe (#20420)
DeduplicateGetAttrPass gives duplicate get_attr nodes different backing attrs before PT2E folding. It missed the case where one get_attr node has multiple users, for example:

    w = get_attr("weight")
    conv1 = conv2d(x, w)
    conv2 = conv2d(y, w)

Some passes preceding DeduplicateGetAttrPass hid this issue because they rebuilt the graph through ExportPass. That rebuild could split the shared get_attr node into one node per use, even when the pass did not change any operators.

Handle that case in DeduplicateGetAttrPass. Keep the original get_attr node for the first user, then create new get_attr nodes for the other users with the same target and copied metadata. The existing dedupe logic can then give those get_attr nodes different backing attrs. This keeps shared parameter handling correct without relying on graph rebuild side effects.

Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent 59fc93d commit 6f6225c

2 files changed

Lines changed: 63 additions & 1 deletion

File tree

backends/arm/_passes/deduplicate_get_attr_pass.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from executorch.backends.arm._passes import ArmPass
1010
from executorch.exir.pass_base import ExportPass, PassResult
1111
from torch.fx import GraphModule, Node
12+
from torch.fx.node import map_arg
1213
from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
1314

1415

@@ -24,6 +25,13 @@ class DeduplicateGetAttrPass(ArmPass):
2425

2526
_passes_required_after: Set[Type[ExportPass]] = set()
2627

28+
def _replace_input_node(self, node: Node, old_node: Node, new_node: Node) -> None:
    """Rewire ``node`` so each reference to ``old_node`` points at ``new_node``.

    Both the positional and the keyword arguments of ``node`` are rewritten.
    Matching is by object identity, and ``map_arg`` is used so that
    references nested inside argument containers are visited as well.

    Args:
        node: The user node whose inputs are rewritten in place.
        old_node: The input node to look for.
        new_node: The node substituted wherever ``old_node`` is found.
    """

    def swap(candidate: Any) -> Any:
        if candidate is old_node:
            return new_node
        return candidate

    rewired_args = map_arg(node.args, swap)
    rewired_kwargs = map_arg(node.kwargs, swap)
    node.args = rewired_args
    node.kwargs = rewired_kwargs
34+
2735
def _get_attr(self, graph_module: GraphModule, target: str) -> Any:
2836
attr: Any = graph_module
2937
for target_atom in target.split("."):
@@ -51,9 +59,26 @@ def _copy_attr(self, graph_module: GraphModule, node: Node) -> str:
5159

5260
return attr_name
5361

62+
def _split_shared_get_attrs(self, graph_module: GraphModule) -> bool:
    """Give every user of a shared ``get_attr`` node its own ``get_attr`` node.

    The first user keeps the original node. For each remaining user a new
    ``get_attr`` node with the same target is inserted directly before that
    user, the original node's metadata entries are copied onto it, and the
    user is rewired to read from the new node.

    Args:
        graph_module: The graph module whose graph is modified in place.

    Returns:
        True if at least one new ``get_attr`` node was created, else False.
    """
    graph = graph_module.graph
    changed = False

    # Snapshot the get_attr nodes: new ones are inserted while we loop.
    for shared in list(graph.find_nodes(op="get_attr")):
        # Snapshot the users as well, since rewiring below mutates
        # ``shared.users``. Zero or one user leaves nothing to split.
        for consumer in list(shared.users)[1:]:
            with graph.inserting_before(consumer):
                duplicate = graph.get_attr(shared.target)
            duplicate.meta.update(shared.meta)
            self._replace_input_node(consumer, shared, duplicate)
            changed = True

    return changed
78+
5479
def call(self, graph_module: GraphModule) -> PassResult:
5580
seen_targets: set[str] = set()
56-
modified = False
81+
modified = self._split_shared_get_attrs(graph_module)
5782

5883
for node in graph_module.graph.find_nodes(op="get_attr"):
5984

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.backends.arm._passes.deduplicate_get_attr_pass import (
8+
DeduplicateGetAttrPass,
9+
)
10+
from torch.fx import Graph, GraphModule
11+
12+
13+
def test_deduplicate_get_attr_splits_shared_node_users() -> None:
    """A get_attr node shared by two users is split and then deduplicated."""
    shared_buffer = torch.ones(2, 2)
    module = torch.nn.Module()
    module.register_buffer("shared", shared_buffer)

    # Build: out = (x + shared) - shared, with one get_attr feeding both ops.
    fx_graph = Graph()
    inp = fx_graph.placeholder("x")
    shared_node = fx_graph.get_attr("shared")
    add_node = fx_graph.call_function(torch.ops.aten.add.Tensor, (inp, shared_node))
    sub_node = fx_graph.call_function(
        torch.ops.aten.sub.Tensor, (add_node, shared_node)
    )
    fx_graph.output(sub_node)
    graph_module = GraphModule(module, fx_graph)

    pass_result = DeduplicateGetAttrPass()(graph_module)

    assert pass_result is not None
    assert pass_result.modified

    attr_nodes = list(graph_module.graph.find_nodes(op="get_attr"))
    assert len(attr_nodes) == 2
    distinct_targets = {attr_node.target for attr_node in attr_nodes}
    assert len(distinct_targets) == 2

    # Each user reads its own get_attr node, in graph order.
    assert add_node.args[1] is attr_nodes[0]
    assert sub_node.args[1] is attr_nodes[1]

    # Both targets still resolve to the very same tensor object.
    for attr_node in attr_nodes:
        assert getattr(graph_module, attr_node.target) is shared_buffer

0 commit comments

Comments
 (0)