Skip to content

Commit f2e5c7c

Browse files
Praneeth Yashovardhan KademPraneeth Yashovardhan Kadem
authored andcommitted
Preserve Arm partition tag order
1 parent c48ea12 commit f2e5c7c

2 files changed

Lines changed: 93 additions & 6 deletions

File tree

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from types import SimpleNamespace
7+
8+
from executorch.backends.arm.tosa import partitioner as tosa_partitioner
9+
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
10+
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
11+
12+
13+
class _FakeCapabilityBasedPartitioner:
14+
def __init__(self, *args, **kwargs) -> None:
15+
pass
16+
17+
def propose_partitions(self):
18+
return [
19+
SimpleNamespace(nodes=[SimpleNamespace(meta={}, target=f"op{idx}")])
20+
for idx in range(3)
21+
]
22+
23+
24+
def _make_reporter() -> SimpleNamespace:
25+
return SimpleNamespace(
26+
report_reject=lambda *args, **kwargs: None,
27+
get_table_report=lambda: "",
28+
)
29+
30+
31+
def test_tag_module_preserves_partition_discovery_order(monkeypatch):
32+
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
33+
34+
monkeypatch.setattr(
35+
tosa_partitioner, "get_cond_while_submodules_nested", lambda module: []
36+
)
37+
monkeypatch.setattr(
38+
tosa_partitioner, "tosa_support_factory", lambda *args, **kwargs: object()
39+
)
40+
monkeypatch.setattr(
41+
tosa_partitioner,
42+
"CapabilityBasedPartitioner",
43+
_FakeCapabilityBasedPartitioner,
44+
)
45+
monkeypatch.setattr(
46+
partitioner,
47+
"_partition_has_invalid_uint8",
48+
lambda partition, tag: False,
49+
)
50+
monkeypatch.setattr(
51+
partitioner,
52+
"_preserve_io_quantization_enabled",
53+
lambda: False,
54+
)
55+
56+
tags = partitioner._tag_module(
57+
SimpleNamespace(graph=SimpleNamespace(nodes=[])),
58+
SimpleNamespace(),
59+
_make_reporter(),
60+
)
61+
62+
assert tags == ["tag0", "tag1", "tag2"]
63+
64+
65+
def test_partition_preserves_tag_discovery_order(monkeypatch):
66+
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))
67+
68+
monkeypatch.setattr(
69+
partitioner,
70+
"_tag_module",
71+
lambda *args, **kwargs: ["tag2", "tag10"],
72+
)
73+
monkeypatch.setattr(tosa_partitioner, "tag_constant_data", lambda program: None)
74+
monkeypatch.setattr(
75+
tosa_partitioner, "WhyNoPartitionReporter", _make_reporter
76+
)
77+
78+
result = partitioner.partition(SimpleNamespace(graph_module=SimpleNamespace()))
79+
80+
assert list(result.partition_tags) == ["tag2", "tag10"]

backends/arm/tosa/partitioner.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def _tag_module( # noqa
289289
containing_program: ExportedProgram,
290290
reporter: WhyNoPartitionReporter,
291291
tag_iterator: count | None = None,
292-
) -> set[str]:
292+
) -> list[str]:
293293
"""Tag nodes in a module or submodule from the containing program.
294294
295295
Args:
@@ -298,21 +298,25 @@ def _tag_module( # noqa
298298
reporter: A reporter to report why nodes were rejected.
299299
300300
Returns:
301-
A set of strings with the partition tags.
301+
A list of strings with the partition tags in discovery order.
302302
303303
"""
304-
tags: set[str] = set()
304+
# Preserve discovery order so backend lowering sees a deterministic
305+
# partition order across Python processes.
306+
tags: list[str] = []
307+
seen_tags: set[str] = set()
305308
if tag_iterator is None:
306309
tag_iterator = count(0)
307310
for _, submodule, _ in get_cond_while_submodules_nested(module):
308311
submodule_tags = self._tag_module(
309312
submodule, containing_program, reporter, tag_iterator
310313
)
311-
if len(tags & submodule_tags) != 0:
314+
if any(tag in seen_tags for tag in submodule_tags):
312315
raise RuntimeError(
313316
"Got overlapping tags in two different modules, this shouldn't happen."
314317
)
315-
tags = tags | submodule_tags
318+
tags.extend(submodule_tags)
319+
seen_tags.update(submodule_tags)
316320
operator_support = tosa_support_factory(
317321
self.tosa_spec, containing_program, reporter, self.additional_checks
318322
)
@@ -335,7 +339,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
335339

336340
for partition in partition_list:
337341
tag = f"tag{next(tag_iterator)}"
338-
tags.add(tag)
342+
tags.append(tag)
343+
seen_tags.add(tag)
339344

340345
for node in partition.nodes:
341346
node.meta["delegation_tag"] = tag
@@ -364,6 +369,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
364369
reporter,
365370
)
366371
tags.remove(tag)
372+
seen_tags.remove(tag)
367373
continue
368374

369375
# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
@@ -385,6 +391,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
385391
reporter,
386392
)
387393
tags.remove(tag)
394+
seen_tags.remove(tag)
388395
return tags
389396

390397
def partition(self, exported_program: ExportedProgram) -> PartitionResult:

0 commit comments

Comments
 (0)