diff --git a/backends/arm/test/misc/test_partitioner_tag_order.py b/backends/arm/test/misc/test_partitioner_tag_order.py new file mode 100644 index 00000000000..89741eb0074 --- /dev/null +++ b/backends/arm/test/misc/test_partitioner_tag_order.py @@ -0,0 +1,73 @@ +from types import SimpleNamespace + +from executorch.backends.arm.tosa import partitioner as tosa_partitioner +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec +from executorch.backends.arm.tosa.partitioner import TOSAPartitioner + + +class _FakeCapabilityBasedPartitioner: + def __init__(self, *args, **kwargs) -> None: + pass + + def propose_partitions(self): + return [ + SimpleNamespace(nodes=[SimpleNamespace(meta={}, target=f"op{idx}")]) + for idx in range(3) + ] + + +def _make_reporter() -> SimpleNamespace: + return SimpleNamespace( + report_reject=lambda *args, **kwargs: None, + get_table_report=lambda: "", + ) + + +def test_tag_module_preserves_partition_discovery_order(monkeypatch): + partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP")) + + monkeypatch.setattr( + tosa_partitioner, "get_cond_while_submodules_nested", lambda module: [] + ) + monkeypatch.setattr( + tosa_partitioner, "tosa_support_factory", lambda *args, **kwargs: object() + ) + monkeypatch.setattr( + tosa_partitioner, + "CapabilityBasedPartitioner", + _FakeCapabilityBasedPartitioner, + ) + monkeypatch.setattr( + partitioner, + "_partition_has_invalid_uint8", + lambda partition, tag: False, + ) + monkeypatch.setattr( + partitioner, + "_preserve_io_quantization_enabled", + lambda: False, + ) + + tags = partitioner._tag_module( + SimpleNamespace(graph=SimpleNamespace(nodes=[])), + SimpleNamespace(), + _make_reporter(), + ) + + assert tags == ["tag0", "tag1", "tag2"] + + +def test_partition_preserves_tag_discovery_order(monkeypatch): + partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP")) + + monkeypatch.setattr( + partitioner, + "_tag_module", + lambda *args, **kwargs: ["tag2", "tag10"], + ) + monkeypatch.setattr(tosa_partitioner, "tag_constant_data", lambda program: None) + monkeypatch.setattr(tosa_partitioner, "WhyNoPartitionReporter", _make_reporter) + + result = partitioner.partition(SimpleNamespace(graph_module=SimpleNamespace())) + + assert list(result.partition_tags) == ["tag2", "tag10"] diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 75dc2e88151..83aae73a493 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -289,7 +289,7 @@ def _tag_module( # noqa containing_program: ExportedProgram, reporter: WhyNoPartitionReporter, tag_iterator: count | None = None, - ) -> set[str]: + ) -> list[str]: """Tag nodes in a module or submodule from the containing program. Args: @@ -298,21 +298,25 @@ def _tag_module( # noqa reporter: A reporter to report why nodes were rejected. Returns: - A set of strings with the partition tags. + A list of strings with the partition tags in discovery order. """ - tags: set[str] = set() + # Preserve discovery order so backend lowering sees a deterministic + # partition order across Python processes. + tags: list[str] = [] + seen_tags: set[str] = set() if tag_iterator is None: tag_iterator = count(0) for _, submodule, _ in get_cond_while_submodules_nested(module): submodule_tags = self._tag_module( submodule, containing_program, reporter, tag_iterator ) - if len(tags & submodule_tags) != 0: + if any(tag in seen_tags for tag in submodule_tags): raise RuntimeError( "Got overlapping tags in two different modules, this shouldn't happen." ) - tags = tags | submodule_tags + tags.extend(submodule_tags) + seen_tags.update(submodule_tags) operator_support = tosa_support_factory( self.tosa_spec, containing_program, reporter, self.additional_checks ) @@ -335,7 +339,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: for partition in partition_list: tag = f"tag{next(tag_iterator)}" - tags.add(tag) + tags.append(tag) + seen_tags.add(tag) for node in partition.nodes: node.meta["delegation_tag"] = tag @@ -364,6 +369,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: reporter, ) tags.remove(tag) + seen_tags.remove(tag) continue # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." @@ -385,6 +391,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: reporter, ) tags.remove(tag) + seen_tags.remove(tag) return tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: