diff --git a/backends/arm/test/misc/test_partition_cycle_detection.py b/backends/arm/test/misc/test_partition_cycle_detection.py new file mode 100644 index 00000000000..288204d9759 --- /dev/null +++ b/backends/arm/test/misc/test_partition_cycle_detection.py @@ -0,0 +1,92 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.tosa.partitioner import ( + _find_connected_components, + _validate_partition, +) + + +def _build_linear_graph(): + """Build a linear graph: x -> a -> b -> c -> output. + + Returns the graph and nodes (x, a, b, c, output). + """ + graph = torch.fx.Graph() + x = graph.placeholder("x") + a = graph.call_function(torch.add, (x, x)) + b = graph.call_function(torch.mul, (a, a)) + c = graph.call_function(torch.sub, (b, b)) + output = graph.output(c) + return graph, (x, a, b, c, output) + + +class TestValidatePartition(unittest.TestCase): + def test_contiguous_partition_is_valid(self): + """A contiguous slice of a linear graph has no cycle.""" + _, (_, a, b, _, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a, b})) + + def test_non_contiguous_partition_has_cycle(self): + """Nodes {a, c} with b in between create a cycle: extracting a and c + would force b to depend on a (inside) and c to depend on b (outside), + while c is also inside. + """ + _, (_, a, _, c, _) = _build_linear_graph() + self.assertFalse(_validate_partition({a, c})) + + def test_single_node_is_valid(self): + _, (_, a, _, _, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a})) + + def test_full_graph_interior_is_valid(self): + """All interior nodes form a valid partition.""" + _, (_, a, b, c, _) = _build_linear_graph() + self.assertTrue(_validate_partition({a, b, c})) + + +class TestFindConnectedComponents(unittest.TestCase): + def test_single_component(self): + _, (_, a, b, _, _) = _build_linear_graph() + components = _find_connected_components({a, b}) + self.assertEqual(len(components), 1) + self.assertEqual(components[0], {a, b}) + + def test_disconnected_components(self): + """Nodes {a, c} with b not in the set form two components.""" + _, (_, a, _, c, _) = _build_linear_graph() + components = _find_connected_components({a, c}) + self.assertEqual(len(components), 2) + component_sets = [frozenset(c) for c in components] + self.assertIn(frozenset({a}), component_sets) + self.assertIn(frozenset({c}), component_sets) + + def test_empty_set(self): + components = _find_connected_components(set()) + self.assertEqual(len(components), 0) + + def test_branching_graph(self): + """Graph with a fork: x -> a -> b, x -> a -> c. {b, c} are disconnected + when a is excluded.""" + graph = torch.fx.Graph() + x = graph.placeholder("x") + a = graph.call_function(torch.add, (x, x)) + b = graph.call_function(torch.mul, (a, a)) + c = graph.call_function(torch.sub, (a, a)) + _ = graph.output((b, c)) + + components = _find_connected_components({b, c}) + self.assertEqual(len(components), 2) + + # With a included, all three form one component + components = _find_connected_components({a, b, c}) + self.assertEqual(len(components), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index a7ef79abbef..7b4d068e9b8 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -14,8 +14,9 @@ """ import logging +from collections import deque from itertools import count -from typing import Callable, List, Optional, Sequence, Tuple +from typing import Callable, Iterable, List, Optional, Sequence, Tuple import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -42,7 +43,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import any_chain, OperatorSupportBase logger = logging.getLogger(__name__) @@ -111,18 +112,19 @@ def is_partitioned( def reject_partition( - reason: str, partition: Partition, reporter: WhyNoPartitionReporter + reason: str, + nodes: Iterable[torch.fx.Node], + reporter: WhyNoPartitionReporter, ) -> None: """Remove a proposed partition and record the rejection reason. Args: reason (str): Human-readable explanation for rejection. - partition (object): Proposed partition object from the - capability partitioner. + nodes: The nodes to de-tag. reporter (WhyNoPartitionReporter): used to report why nodes were rejected. """ - for node in partition.nodes: + for node in nodes: if "delegation_tag" in node.meta: del node.meta["delegation_tag"] reporter.report_reject( @@ -131,6 +133,77 @@ def reject_partition( ) +def _validate_partition(nodes: set[torch.fx.Node]) -> bool: + """Check whether a set of nodes can be extracted as a subgraph without + cycles. + + Perform a BFS from the external users of partition nodes. If any node + reached by BFS is itself inside the partition, then extracting the + partition would create a dependency cycle in the remaining graph. + + Args: + nodes: The set of FX nodes that form the partition. + + Returns: + True if the partition is valid (no cycles), False otherwise. + + """ + outputs: list[torch.fx.Node] = [] + for node in nodes: + for user in node.users: + if user not in nodes: + outputs.append(user) + + visited: set[torch.fx.Node] = set() + queue = deque(outputs) + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + if current in nodes: + return False + for user in current.users: + if user not in visited: + queue.append(user) + return True + + +def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]: + """Find connected components in a set of nodes treating edges as undirected. + + Two nodes are connected if one is an input or user of the other and both + are in ``nodes``. + + Args: + nodes: The node set to partition into components. + + Returns: + A list of disjoint node sets, one per connected component. + + """ + remaining = set(nodes) + components: list[set[torch.fx.Node]] = [] + while remaining: + seed = next(iter(remaining)) + component: set[torch.fx.Node] = set() + queue = deque([seed]) + while queue: + node = queue.popleft() + if node in component or node not in remaining: + continue + component.add(node) + for inp in node.all_input_nodes: + if inp in remaining and inp not in component: + queue.append(inp) + for user in node.users: + if user in remaining and user not in component: + queue.append(user) + remaining -= component + components.append(component) + return components + + class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -285,25 +358,60 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: reporter, ) - # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." - is_nocompute_partition = all( - _is_noop_clone(node) - or _is_noop_alias_copy(node) - or _is_noop_expand(node) - or _is_noop_detach_copy(node) - or _is_noop_to_dim_order_copy(node) - or _is_view_copy(node) - or node.target in Q_OPS - or node.target in DQ_OPS - for node in partition.nodes - ) - if is_nocompute_partition: - reject_partition( - "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", - partition, - reporter, + # After de-tagging, the remaining tagged nodes may form + # dependency cycles. This happens when models contain complex + # attention blocks (e.g. MobileViT) where Q/DQ nodes act as + # bridges between partition segments. Detect such cycles and + # split the partition into valid connected components. + surviving = {n for n in partition.nodes if is_partitioned(n, tag)} + if surviving and not _validate_partition(surviving): + components = _find_connected_components(surviving) + logger.info( + f"Partition {tag} has dependency cycle after Q/DQ " + f"de-tagging. Splitting into {len(components)} " + f"sub-partition(s)." + ) + # Remove the original tag from all nodes + for node in surviving: + del node.meta["delegation_tag"] + tags.remove(tag) + # Re-tag each connected component as a new partition + for component in components: + new_tag = f"tag{next(tag_iterator)}" + tags.add(new_tag) + for node in component: + node.meta["delegation_tag"] = new_tag + + # Check whether the partition contains only no-op or non-computational + # ops. Such partitions don't make sense to delegate, and in the worst + # case may be optimized away during lowering, which can break + # compilation. After a cycle split the nodes may belong to multiple + # sub-partitions, so collect every active tag and check each group. + active_tags: dict[str, list[torch.fx.Node]] = {} + for node in partition.nodes: + node_tag = node.meta.get("delegation_tag") + if node_tag is not None and node_tag in tags: + active_tags.setdefault(node_tag, []).append(node) + + for active_tag, nodes in active_tags.items(): + is_nocompute_partition = all( + _is_noop_clone(node) + or _is_noop_alias_copy(node) + or _is_noop_expand(node) + or _is_noop_detach_copy(node) + or _is_noop_to_dim_order_copy(node) + or _is_view_copy(node) + or node.target in Q_OPS + or node.target in DQ_OPS + for node in nodes ) - tags.remove(tag) + if is_nocompute_partition: + reject_partition( + "Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.", + nodes, + reporter, + ) + tags.remove(active_tag) return tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: