Skip to content

Commit c94b062

Browse files
committed
fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging
`_detag_boundary_nodes` removes Q/DQ nodes from partition boundaries after `CapabilityBasedPartitioner` has produced cycle-free partitions. However, this de-tagging can introduce dependency cycles for models with complex attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped into a single large partition). The cycle occurs because removing Q/DQ bridge nodes creates paths that exit the partition and re-enter it through the now-unpartitioned nodes, making it impossible to extract the partition as a valid subgraph. This change adds cycle validation after `_detag_boundary_nodes`. When a cycle is detected, the partition is split into connected components of the surviving (still-tagged) nodes. Each component becomes a separate partition that is individually cycle-free after de-tagging. - Add `_validate_partition()`: BFS-based cycle detection (same algorithm as `torch.fx.passes.utils.fuser_utils.validate_partition`) - Add `_find_connected_components()`: undirected graph traversal to split surviving nodes into disjoint sub-partitions - Guard the nocompute-partition `tags.remove()` against already-removed tags from the cycle-split path Tested with MobileViT-S on Ethos-U85: previously failed with `AssertionError: Invalid partition, found dependency cycles`, now successfully produces a .pte file (5.7 MB). Nine attention-block partitions are each split into 3 sub-partitions. All sub-partitions remain on NPU (no CPU fallback). Existing CNN-only models (ResNet, MobileNetV2, EfficientNet) are unaffected as their partitions have no cycles after de-tagging.
1 parent 87e65ac commit c94b062

1 file changed

Lines changed: 98 additions & 1 deletion

File tree

backends/arm/tosa/partitioner.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import logging
17+
from collections import deque
1718
from itertools import count
1819
from typing import Callable, List, Optional, Sequence, Tuple
1920

@@ -131,6 +132,77 @@ def reject_partition(
131132
)
132133

133134

135+
def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
136+
"""Check whether a set of nodes can be extracted as a subgraph without
137+
cycles.
138+
139+
Perform a BFS from the external users of partition nodes. If any node
140+
reached by BFS is itself inside the partition, then extracting the
141+
partition would create a dependency cycle in the remaining graph.
142+
143+
Args:
144+
nodes: The set of FX nodes that form the partition.
145+
146+
Returns:
147+
True if the partition is valid (no cycles), False otherwise.
148+
149+
"""
150+
outputs: list[torch.fx.Node] = []
151+
for node in nodes:
152+
for user in node.users:
153+
if user not in nodes:
154+
outputs.append(user)
155+
156+
visited: set[torch.fx.Node] = set()
157+
queue = deque(outputs)
158+
while queue:
159+
current = queue.popleft()
160+
if current in visited:
161+
continue
162+
visited.add(current)
163+
if current in nodes:
164+
return False
165+
for user in current.users:
166+
if user not in visited:
167+
queue.append(user)
168+
return True
169+
170+
171+
def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
172+
"""Find connected components in a set of nodes treating edges as undirected.
173+
174+
Two nodes are connected if one is an input or user of the other and both
175+
are in ``nodes``.
176+
177+
Args:
178+
nodes: The node set to partition into components.
179+
180+
Returns:
181+
A list of disjoint node sets, one per connected component.
182+
183+
"""
184+
remaining = set(nodes)
185+
components: list[set[torch.fx.Node]] = []
186+
while remaining:
187+
seed = next(iter(remaining))
188+
component: set[torch.fx.Node] = set()
189+
queue = deque([seed])
190+
while queue:
191+
node = queue.popleft()
192+
if node in component or node not in remaining:
193+
continue
194+
component.add(node)
195+
for inp in node.all_input_nodes:
196+
if inp in remaining and inp not in component:
197+
queue.append(inp)
198+
for user in node.users:
199+
if user in remaining and user not in component:
200+
queue.append(user)
201+
remaining -= component
202+
components.append(component)
203+
return components
204+
205+
134206
class TOSAPartitioner(Partitioner):
135207
"""Partition an exported program into TOSA-delegable subgraphs.
136208
@@ -285,6 +357,30 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
285357
reporter,
286358
)
287359

360+
# After de-tagging, the remaining tagged nodes may form
361+
# dependency cycles. This happens when models contain complex
362+
# attention blocks (e.g. MobileViT) where Q/DQ nodes act as
363+
# bridges between partition segments. Detect such cycles and
364+
# split the partition into valid connected components.
365+
surviving = {n for n in partition.nodes if is_partitioned(n, tag)}
366+
if surviving and not _validate_partition(surviving):
367+
components = _find_connected_components(surviving)
368+
logger.info(
369+
f"Partition {tag} has dependency cycle after Q/DQ "
370+
f"de-tagging. Splitting into {len(components)} "
371+
f"sub-partition(s)."
372+
)
373+
# Remove the original tag from all nodes
374+
for node in surviving:
375+
del node.meta["delegation_tag"]
376+
tags.remove(tag)
377+
# Re-tag each connected component as a new partition
378+
for component in components:
379+
new_tag = f"tag{next(tag_iterator)}"
380+
tags.add(new_tag)
381+
for node in component:
382+
node.meta["delegation_tag"] = new_tag
383+
288384
# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
289385
is_nocompute_partition = all(
290386
_is_noop_clone(node)
@@ -303,7 +399,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
303399
partition,
304400
reporter,
305401
)
306-
tags.remove(tag)
402+
if tag in tags:
403+
tags.remove(tag)
307404
return tags
308405

309406
def partition(self, exported_program: ExportedProgram) -> PartitionResult:

0 commit comments

Comments
 (0)