Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions backends/arm/test/misc/test_partitioner_tag_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2026 Arm Limited and/or its affiliates.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think this is needed here if you are not a Arm employee. Once/if we touch the file we will add the copyright.

#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from types import SimpleNamespace

from executorch.backends.arm.tosa import partitioner as tosa_partitioner
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner


class _FakeCapabilityBasedPartitioner:
def __init__(self, *args, **kwargs) -> None:
pass

def propose_partitions(self):
return [
SimpleNamespace(nodes=[SimpleNamespace(meta={}, target=f"op{idx}")])
for idx in range(3)
]


def _make_reporter() -> SimpleNamespace:
return SimpleNamespace(
report_reject=lambda *args, **kwargs: None,
get_table_report=lambda: "",
)


def test_tag_module_preserves_partition_discovery_order(monkeypatch):
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))

monkeypatch.setattr(
tosa_partitioner, "get_cond_while_submodules_nested", lambda module: []
)
monkeypatch.setattr(
tosa_partitioner, "tosa_support_factory", lambda *args, **kwargs: object()
)
monkeypatch.setattr(
tosa_partitioner,
"CapabilityBasedPartitioner",
_FakeCapabilityBasedPartitioner,
)
monkeypatch.setattr(
partitioner,
"_partition_has_invalid_uint8",
lambda partition, tag: False,
)
monkeypatch.setattr(
partitioner,
"_preserve_io_quantization_enabled",
lambda: False,
)

tags = partitioner._tag_module(
SimpleNamespace(graph=SimpleNamespace(nodes=[])),
SimpleNamespace(),
_make_reporter(),
)

assert tags == ["tag0", "tag1", "tag2"]


def test_partition_preserves_tag_discovery_order(monkeypatch):
partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP"))

monkeypatch.setattr(
partitioner,
"_tag_module",
lambda *args, **kwargs: ["tag2", "tag10"],
)
monkeypatch.setattr(tosa_partitioner, "tag_constant_data", lambda program: None)
monkeypatch.setattr(
tosa_partitioner, "WhyNoPartitionReporter", _make_reporter
)

result = partitioner.partition(SimpleNamespace(graph_module=SimpleNamespace()))

assert list(result.partition_tags) == ["tag2", "tag10"]
19 changes: 13 additions & 6 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _tag_module( # noqa
containing_program: ExportedProgram,
reporter: WhyNoPartitionReporter,
tag_iterator: count | None = None,
) -> set[str]:
) -> list[str]:
"""Tag nodes in a module or submodule from the containing program.

Args:
Expand All @@ -298,21 +298,25 @@ def _tag_module( # noqa
reporter: A reporter to report why nodes were rejected.

Returns:
A set of strings with the partition tags.
A list of strings with the partition tags in discovery order.

"""
tags: set[str] = set()
# Preserve discovery order so backend lowering sees a deterministic
# partition order across Python processes.
tags: list[str] = []
seen_tags: set[str] = set()
if tag_iterator is None:
tag_iterator = count(0)
for _, submodule, _ in get_cond_while_submodules_nested(module):
submodule_tags = self._tag_module(
submodule, containing_program, reporter, tag_iterator
)
if len(tags & submodule_tags) != 0:
if any(tag in seen_tags for tag in submodule_tags):
raise RuntimeError(
"Got overlapping tags in two different modules, this shouldn't happen."
)
tags = tags | submodule_tags
tags.extend(submodule_tags)
seen_tags.update(submodule_tags)
operator_support = tosa_support_factory(
self.tosa_spec, containing_program, reporter, self.additional_checks
)
Expand All @@ -335,7 +339,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

for partition in partition_list:
tag = f"tag{next(tag_iterator)}"
tags.add(tag)
tags.append(tag)
seen_tags.add(tag)

for node in partition.nodes:
node.meta["delegation_tag"] = tag
Expand Down Expand Up @@ -364,6 +369,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
reporter,
)
tags.remove(tag)
seen_tags.remove(tag)
continue

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
Expand All @@ -385,6 +391,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
reporter,
)
tags.remove(tag)
seen_tags.remove(tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down
Loading