Skip to content

Commit 5f597ab

Browse files
beomwookangbeomwoo-kang
authored and committed
fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging
`_detag_boundary_nodes` removes Q/DQ nodes from partition boundaries after `CapabilityBasedPartitioner` has produced cycle-free partitions. However, this de-tagging can introduce dependency cycles for models with complex attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped into a single large partition). The cycle occurs because removing Q/DQ bridge nodes creates paths that exit the partition and re-enter it through the now-unpartitioned nodes, making it impossible to extract the partition as a valid subgraph.

This change adds cycle validation after `_detag_boundary_nodes`. When a cycle is detected, the partition is split into connected components of the surviving (still-tagged) nodes. Each component becomes a separate partition that is individually cycle-free after de-tagging.

- Add `_validate_partition()`: BFS-based cycle detection (same algorithm as `torch.fx.passes.utils.fuser_utils.validate_partition`).
- Add `_find_connected_components()`: undirected graph traversal to split surviving nodes into disjoint sub-partitions.
- Guard the nocompute-partition `tags.remove()` against already-removed tags from the cycle-split path.

Tested with MobileViT-S on Ethos-U85: previously failed with `AssertionError: Invalid partition, found dependency cycles`, now successfully produces a .pte file (5.7 MB). Nine attention-block partitions are each split into 3 sub-partitions. All sub-partitions remain on NPU (no CPU fallback). Existing CNN-only models (ResNet, MobileNetV2, EfficientNet) are unaffected, as their partitions have no cycles after de-tagging.
1 parent abc0237 commit 5f597ab

1 file changed

Lines changed: 99 additions & 1 deletion

File tree

backends/arm/tosa/partitioner.py

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import logging
17+
from collections import deque
1718
from itertools import count
1819
from typing import Callable, List, Optional, Sequence, Tuple
1920

@@ -118,6 +119,76 @@ def reject_partition(
118119
)
119120

120121

122+
def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
123+
"""Check whether a set of nodes can be extracted as a subgraph without cycles.
124+
125+
Perform a BFS from the external users of partition nodes. If any node
126+
reached by BFS is itself inside the partition, then extracting the
127+
partition would create a dependency cycle in the remaining graph.
128+
129+
Args:
130+
nodes: The set of FX nodes that form the partition.
131+
132+
Returns:
133+
True if the partition is valid (no cycles), False otherwise.
134+
135+
"""
136+
outputs: list[torch.fx.Node] = []
137+
for node in nodes:
138+
for user in node.users:
139+
if user not in nodes:
140+
outputs.append(user)
141+
142+
visited: set[torch.fx.Node] = set()
143+
queue = deque(outputs)
144+
while queue:
145+
current = queue.popleft()
146+
if current in visited:
147+
continue
148+
visited.add(current)
149+
if current in nodes:
150+
return False
151+
for user in current.users:
152+
if user not in visited:
153+
queue.append(user)
154+
return True
155+
156+
157+
def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
158+
"""Find connected components in a set of nodes treating edges as undirected.
159+
160+
Two nodes are connected if one is an input or user of the other and both
161+
are in ``nodes``.
162+
163+
Args:
164+
nodes: The node set to partition into components.
165+
166+
Returns:
167+
A list of disjoint node sets, one per connected component.
168+
169+
"""
170+
remaining = set(nodes)
171+
components: list[set[torch.fx.Node]] = []
172+
while remaining:
173+
seed = next(iter(remaining))
174+
component: set[torch.fx.Node] = set()
175+
queue = deque([seed])
176+
while queue:
177+
node = queue.popleft()
178+
if node in component or node not in remaining:
179+
continue
180+
component.add(node)
181+
for inp in node.all_input_nodes:
182+
if inp in remaining and inp not in component:
183+
queue.append(inp)
184+
for user in node.users:
185+
if user in remaining and user not in component:
186+
queue.append(user)
187+
remaining -= component
188+
components.append(component)
189+
return components
190+
191+
121192
class TOSAPartitioner(Partitioner):
122193
"""Partition an exported program into TOSA-delegable subgraphs.
123194
@@ -255,6 +326,32 @@ def _tag_module( # noqa
255326
reporter,
256327
)
257328

329+
# After de-tagging, the remaining tagged nodes may form
330+
# dependency cycles. This happens when models contain complex
331+
# attention blocks (e.g. MobileViT) where Q/DQ nodes act as
332+
# bridges between partition segments. Detect such cycles and
333+
# split the partition into valid connected components.
334+
surviving = {
335+
n for n in partition.nodes if is_partitioned(n, tag)
336+
}
337+
if surviving and not _validate_partition(surviving):
338+
components = _find_connected_components(surviving)
339+
logger.info(
340+
f"Partition {tag} has dependency cycle after Q/DQ "
341+
f"de-tagging. Splitting into {len(components)} "
342+
f"sub-partition(s)."
343+
)
344+
# Remove the original tag from all nodes
345+
for node in surviving:
346+
del node.meta["delegation_tag"]
347+
tags.remove(tag)
348+
# Re-tag each connected component as a new partition
349+
for component in components:
350+
new_tag = f"tag{next(tag_iterator)}"
351+
tags.add(new_tag)
352+
for node in component:
353+
node.meta["delegation_tag"] = new_tag
354+
258355
# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation.
259356
is_nocompute_partition = all(
260357
_is_noop_clone(node)
@@ -272,7 +369,8 @@ def _tag_module( # noqa
272369
partition,
273370
reporter,
274371
)
275-
tags.remove(tag)
372+
if tag in tags:
373+
tags.remove(tag)
276374
return tags
277375

278376
def partition(self, exported_program: ExportedProgram) -> PartitionResult:

0 commit comments

Comments
 (0)