Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions backends/arm/test/misc/test_partition_cycle_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.tosa.partitioner import (
_find_connected_components,
_validate_partition,
)


def _build_linear_graph():
"""Build a linear graph: x -> a -> b -> c -> output.

Returns the graph and nodes (x, a, b, c, output).
"""
graph = torch.fx.Graph()
x = graph.placeholder("x")
a = graph.call_function(torch.add, (x, x))
b = graph.call_function(torch.mul, (a, a))
c = graph.call_function(torch.sub, (b, b))
output = graph.output(c)
return graph, (x, a, b, c, output)


class TestValidatePartition(unittest.TestCase):
    """Cycle-detection checks for ``_validate_partition`` on a linear graph."""

    def test_contiguous_partition_is_valid(self):
        """A contiguous slice of a linear graph has no cycle."""
        _, (_, a, b, _, _) = _build_linear_graph()
        self.assertTrue(_validate_partition({a, b}))

    def test_non_contiguous_partition_has_cycle(self):
        """Nodes {a, c} with b in between create a cycle.

        Fusing a and c into one subgraph would make that subgraph feed b
        (through a) while also consuming b (through c), so the subgraph and
        b would depend on each other.
        """
        _, (_, a, _, c, _) = _build_linear_graph()
        self.assertFalse(_validate_partition({a, c}))

    def test_single_node_is_valid(self):
        """A lone node can never re-enter its own partition."""
        _, (_, a, _, _, _) = _build_linear_graph()
        self.assertTrue(_validate_partition({a}))

    def test_full_graph_interior_is_valid(self):
        """All interior nodes form a valid partition."""
        _, (_, a, b, c, _) = _build_linear_graph()
        self.assertTrue(_validate_partition({a, b, c}))


class TestFindConnectedComponents(unittest.TestCase):
    """Checks for ``_find_connected_components`` on small FX graphs."""

    def test_single_component(self):
        """Two adjacent nodes end up in the same component."""
        _, (_, a, b, _, _) = _build_linear_graph()
        components = _find_connected_components({a, b})
        self.assertEqual(len(components), 1)
        self.assertEqual(components[0], {a, b})

    def test_disconnected_components(self):
        """Nodes {a, c} with b not in the set form two components."""
        _, (_, a, _, c, _) = _build_linear_graph()
        components = _find_connected_components({a, c})
        self.assertEqual(len(components), 2)
        # Component order is unspecified, so compare as hashable sets. The
        # loop variable is deliberately not named ``c`` to avoid shadowing
        # the node used in the assertions below.
        component_sets = [frozenset(component) for component in components]
        self.assertIn(frozenset({a}), component_sets)
        self.assertIn(frozenset({c}), component_sets)

    def test_empty_set(self):
        """An empty node set yields no components."""
        components = _find_connected_components(set())
        self.assertEqual(len(components), 0)

    def test_branching_graph(self):
        """Graph with a fork: x -> a -> b, x -> a -> c.

        {b, c} are disconnected when a is excluded, and joined through a
        when it is included.
        """
        graph = torch.fx.Graph()
        x = graph.placeholder("x")
        a = graph.call_function(torch.add, (x, x))
        b = graph.call_function(torch.mul, (a, a))
        c = graph.call_function(torch.sub, (a, a))
        _ = graph.output((b, c))

        components = _find_connected_components({b, c})
        self.assertEqual(len(components), 2)

        # With a included, all three form one component
        components = _find_connected_components({a, b, c})
        self.assertEqual(len(components), 1)
        self.assertEqual(components[0], {a, b, c})


if __name__ == "__main__":
unittest.main()
156 changes: 132 additions & 24 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
"""

import logging
from collections import deque
from itertools import count
from typing import Callable, List, Optional, Sequence, Tuple
from typing import Callable, Iterable, List, Optional, Sequence, Tuple

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
Expand All @@ -42,7 +43,7 @@
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -111,18 +112,19 @@ def is_partitioned(


def reject_partition(
reason: str, partition: Partition, reporter: WhyNoPartitionReporter
reason: str,
nodes: Iterable[torch.fx.Node],
reporter: WhyNoPartitionReporter,
) -> None:
"""Remove a proposed partition and record the rejection reason.

Args:
reason (str): Human-readable explanation for rejection.
partition (object): Proposed partition object from the
capability partitioner.
nodes: The nodes to de-tag.
reporter (WhyNoPartitionReporter): used to report why nodes were rejected.

"""
for node in partition.nodes:
for node in nodes:
if "delegation_tag" in node.meta:
del node.meta["delegation_tag"]
reporter.report_reject(
Expand All @@ -131,6 +133,77 @@ def reject_partition(
)


def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
"""Check whether a set of nodes can be extracted as a subgraph without
cycles.

Perform a BFS from the external users of partition nodes. If any node
reached by BFS is itself inside the partition, then extracting the
partition would create a dependency cycle in the remaining graph.

Args:
nodes: The set of FX nodes that form the partition.

Returns:
True if the partition is valid (no cycles), False otherwise.

"""
outputs: list[torch.fx.Node] = []
for node in nodes:
for user in node.users:
if user not in nodes:
outputs.append(user)

visited: set[torch.fx.Node] = set()
queue = deque(outputs)
while queue:
current = queue.popleft()
if current in visited:
continue
visited.add(current)
if current in nodes:
return False
for user in current.users:
if user not in visited:
queue.append(user)
return True


def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
"""Find connected components in a set of nodes treating edges as undirected.

Two nodes are connected if one is an input or user of the other and both
are in ``nodes``.

Args:
nodes: The node set to partition into components.

Returns:
A list of disjoint node sets, one per connected component.

"""
remaining = set(nodes)
components: list[set[torch.fx.Node]] = []
while remaining:
seed = next(iter(remaining))
component: set[torch.fx.Node] = set()
queue = deque([seed])
while queue:
node = queue.popleft()
if node in component or node not in remaining:
continue
component.add(node)
for inp in node.all_input_nodes:
if inp in remaining and inp not in component:
queue.append(inp)
for user in node.users:
if user in remaining and user not in component:
queue.append(user)
remaining -= component
components.append(component)
return components


class TOSAPartitioner(Partitioner):
"""Partition an exported program into TOSA-delegable subgraphs.

Expand Down Expand Up @@ -285,25 +358,60 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
reporter,
)

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
is_nocompute_partition = all(
_is_noop_clone(node)
or _is_noop_alias_copy(node)
or _is_noop_expand(node)
or _is_noop_detach_copy(node)
or _is_noop_to_dim_order_copy(node)
or _is_view_copy(node)
or node.target in Q_OPS
or node.target in DQ_OPS
for node in partition.nodes
)
if is_nocompute_partition:
reject_partition(
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
partition,
reporter,
# After de-tagging, the remaining tagged nodes may form
# dependency cycles. This happens when models contain complex
# attention blocks (e.g. MobileViT) where Q/DQ nodes act as
# bridges between partition segments. Detect such cycles and
# split the partition into valid connected components.
surviving = {n for n in partition.nodes if is_partitioned(n, tag)}
if surviving and not _validate_partition(surviving):
components = _find_connected_components(surviving)
logger.info(
f"Partition {tag} has dependency cycle after Q/DQ "
f"de-tagging. Splitting into {len(components)} "
f"sub-partition(s)."
)
# Remove the original tag from all nodes
for node in surviving:
del node.meta["delegation_tag"]
tags.remove(tag)
# Re-tag each connected component as a new partition
for component in components:
new_tag = f"tag{next(tag_iterator)}"
tags.add(new_tag)
for node in component:
node.meta["delegation_tag"] = new_tag

# Check whether the partition contains only no-op or non-computational
# ops. Such partitions don't make sense to delegate, and in the worst
# case may be optimized away during lowering, which can break
# compilation. After a cycle split the nodes may belong to multiple
# sub-partitions, so collect every active tag and check each group.
active_tags: dict[str, list[torch.fx.Node]] = {}
for node in partition.nodes:
node_tag = node.meta.get("delegation_tag")
if node_tag is not None and node_tag in tags:
active_tags.setdefault(node_tag, []).append(node)

for active_tag, nodes in active_tags.items():
is_nocompute_partition = all(
_is_noop_clone(node)
or _is_noop_alias_copy(node)
or _is_noop_expand(node)
or _is_noop_detach_copy(node)
or _is_noop_to_dim_order_copy(node)
or _is_view_copy(node)
or node.target in Q_OPS
or node.target in DQ_OPS
for node in nodes
)
tags.remove(tag)
if is_nocompute_partition:
reject_partition(
"Partition contained only ops which are removed in the TOSA lowering, leading to an empty partition.",
nodes,
reporter,
)
tags.remove(active_tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down
Loading