Arm backend: fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging#18191
Arm backend: fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging#18191beomwookang wants to merge 4 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18191
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
To add the ciflow label This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows. |
|
Hi @beomwookang, could you resolve the failing CI tests and check the lintrunner? Do let us know here once done for a review. Thanks for the bug report and the PR! |
|
@beomwookang pinging again to resolve CI tests failures |
…ging `_detag_boundary_nodes` removes Q/DQ nodes from partition boundaries after `CapabilityBasedPartitioner` has produced cycle-free partitions. However, this de-tagging can introduce dependency cycles for models with complex attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped into a single large partition). The cycle occurs because removing Q/DQ bridge nodes creates paths that exit the partition and re-enter it through the now-unpartitioned nodes, making it impossible to extract the partition as a valid subgraph. This change adds cycle validation after `_detag_boundary_nodes`. When a cycle is detected, the partition is split into connected components of the surviving (still-tagged) nodes. Each component becomes a separate partition that is individually cycle-free after de-tagging. - Add `_validate_partition()`: BFS-based cycle detection (same algorithm as `torch.fx.passes.utils.fuser_utils.validate_partition`) - Add `_find_connected_components()`: undirected graph traversal to split surviving nodes into disjoint sub-partitions - Guard the nocompute-partition `tags.remove()` against already-removed tags from the cycle-split path Tested with MobileViT-S on Ethos-U85: previously failed with `AssertionError: Invalid partition, found dependency cycles`, now successfully produces a .pte file (5.7 MB). Nine attention-block partitions are each split into 3 sub-partitions. All sub-partitions remain on NPU (no CPU fallback). Existing CNN-only models (ResNet, MobileNetV2, EfficientNet) are unaffected as their partitions have no cycles after de-tagging.
5f597ab to
c94b062
Compare
|
@nil-is-all All remaining CI failures on this PR are pre-existing issues unrelated to the backends/arm/tosa/partitioner.py change:
This PR touches only |
|
@claude Review this PR |
|
Claude finished @digantdesai's task in 3m 35s —— View job Review of PR #18191
SummaryThis PR fixes a real bug where Algorithm Correctness
Issues1. Nocompute-partition check doesn't cover sub-partitions after cycle split (moderate) After the cycle-split path (lines 360-382), the nocompute check at lines 384-403 still iterates over the original
In practice, this means the nocompute check after a cycle split is partially broken: it may incorrectly strip tags from newly created sub-partitions without removing their corresponding tags from the 2. At line 366, 3. Connected components don't guarantee cycle-free sub-partitions (theoretical) Splitting by connected components guarantees that each component is a connected subgraph, but does not strictly guarantee each component is cycle-free. In practice, a single connected component that is contiguous in the directed graph should be extractable without creating external cycles — the cycles are caused by disconnected groups of nodes within the same partition that have paths through external nodes between them. So splitting by connectivity is the right fix here. But it may be worth adding a debug assertion to verify each component passes Style Observations
Missing TestsThere are no unit tests for VerdictThe core fix is correct and addresses a real issue with CNN+Transformer hybrid models on the ARM backend. The two new utility functions are well-implemented. I'd recommend addressing the nocompute check interaction (issue #1) and adding a unit test before merging. The guard change at line 402 ( |
…ions After a cycle split, the nocompute check still iterated over the original partition nodes and only attempted to remove the original tag. This left orphan tags in the returned set when sub-partitions were rejected. Group the nocompute check by active tag so each sub-partition is evaluated and cleaned up independently. Also update reject_partition() to accept an iterable of nodes instead of a Partition object.
Add tests for _validate_partition and _find_connected_components using synthetic torch.fx graphs. Cover contiguous/non-contiguous partitions, single nodes, branching graphs, and empty sets.
|
@digantdesai I reflected the review from Claude. Please check:) |
Fixes #18190
Summary
_detag_boundary_nodes()removes Q/DQ nodes from partition boundaries afterCapabilityBasedPartitionerhas produced cycle-free partitions.However, this de-tagging can introduce dependency cycles for models with complex attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped into a single large partition).
The cycle occurs because removing Q/DQ bridge nodes creates paths that exit the partition and re-enter it through the now-unpartitioned nodes, making it impossible to extract the partition as a valid subgraph:
[partition] Linear_Q → [de-tagged Q] → [outside] → [de-tagged DQ] → [partition] Matmul
[partition] Linear_K → [de-tagged Q] → [outside] → [de-tagged DQ] → [partition] Matmul
Changes
After
_detag_boundary_nodes(), validate each partition for dependency cycles.When a cycle is detected, split the partition into connected components of the surviving tagged nodes.
Each component becomes a separate valid partition.
_validate_partition(): BFS-based cycle detection (same algorithm astorch.fx.passes.utils.fuser_utils.validate_partition)_find_connected_components(): undirected graph traversal to split surviving nodes into disjoint sub-partitionstags.remove()against already-removed tags from the cycle-split pathTest Results
AssertionError: Invalid partition, found dependency cycles, now successfully produces a .pte file (5.7 MB). Nine attention-block partitions are each split into 3 sub-partitions. All sub-partitions remain on NPU (no CPU fallback).How to test