Skip to content

Commit 73425f1

Browse files
fix: check both partitions for cycles in GroupBasedPartitioner._can_merge_partitions
The previous implementation only checked downstream dependencies from p2, assuming p2 always precedes p1 in topological order. This assumption breaks when partition groups contain nodes spanning wide topological ranges — for example, when dynamic quantization inserts a shared `choose_qparams` node consumed by GEMM ops in different sequential transformer decoder layers. In that case the two groups *interleave* in topological order, and the single-direction check misses cycles flowing from p1 through external nodes back into p2. This change: 1. Collects external users from *both* p1 and p2 (combined_nodes) instead of only p2. 2. Adds a `validate_partition` safety net that performs a direct BFS on the live graph edges, catching any cycle the pre-computed `_DependencyViewer` might miss. Fixes `AssertionError: Invalid partition, found dependency cycles` when lowering cross-attention transformer decoders (e.g. DETR) with `XnnpackDynamicallyQuantizedPartitioner`.
1 parent 5a7c523 commit 73425f1

2 files changed

Lines changed: 132 additions & 9 deletions

File tree

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,24 @@ def __init__(
9797

9898
def _can_merge_partitions(self, p1, p2, partitions_by_id):
9999
"""Check if merging two partitions would create a cycle."""
100+
from torch.fx.passes.utils.fuser_utils import validate_partition
101+
100102
p1_nodes = set(partitions_by_id[p1].nodes.keys())
101103
p2_nodes = set(partitions_by_id[p2].nodes.keys())
102104
combined_nodes = p1_nodes.union(p2_nodes)
103105

106+
# Check external users from BOTH partitions. The original code only
107+
# checked p2 under the assumption that p2 is always topologically
108+
# before p1. However, when partition groups contain nodes that span
109+
# wide topological ranges (e.g. due to shared dynamic-quantization
110+
# choose_qparams nodes), the two partitions can *interleave* in
111+
# topological order, making the single-direction check insufficient.
112+
#
113+
# We still only need to collect the *direct* external users (not
114+
# transitive ones), because dependency_viewer.downstreams_of already
115+
# returns the full transitive closure.
104116
user_nodes = []
105-
# topologically, p2_nodes comes before p1_nodes, so we only
106-
# need to check the downstream nodes of p2.
107-
# Additionally, we don't need to check all the downstream nodes
108-
# of p2, we only need to check the nodes directly outside of p2.
109-
# example:
110-
# partition[a --> b --> c] --> d --> e --> f
111-
# we don't need to check [d, e, f] we only need to check [d] because
112-
# the downstream users of [d] will include [e, f]
113-
for node in p2_nodes:
117+
for node in combined_nodes:
114118
for user in node.users:
115119
if user not in combined_nodes:
116120
user_nodes.append(user)
@@ -121,6 +125,13 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
121125
if any(n in combined_nodes for n in downstream_nodes):
122126
return False
123127

128+
# Final safety net: validate_partition performs a direct BFS on the
129+
# live graph edges, catching any cycle the pre-computed
130+
# dependency_viewer might miss (e.g. when the graph was transformed
131+
# after the viewer was built).
132+
if not validate_partition(list(combined_nodes)):
133+
return False
134+
124135
return True
125136

126137
def _process_all_nodes(

exir/backend/test/test_group_partitioner.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,3 +1672,115 @@ def forward(self, x):
16721672

16731673
# With allows_single_node_partition=True, we should have partitions
16741674
self.assertGreater(len(partitions_with_single), 0)
1675+
1676+
def test_interleaved_groups_no_false_merge(self):
1677+
"""
1678+
Test that _can_merge_partitions correctly rejects merges when two
1679+
partition groups interleave in topological order.
1680+
1681+
This reproduces a real-world failure with
1682+
XnnpackDynamicallyQuantizedPartitioner on transformer decoder models
1683+
where cross-attention K/V projections across multiple decoder layers
1684+
share the same encoder ``memory`` input. Dynamic quantization inserts
1685+
a shared ``choose_qparams`` node for that input, causing the DSJ phase
1686+
to create partition groups whose nodes span wide topological ranges.
1687+
When GroupBasedPartitioner later tries to merge these groups, the
1688+
original single-direction downstream check missed the cycle because it
1689+
assumed p2 is entirely before p1 — which is false for interleaved
1690+
groups.
1691+
1692+
The model is a minimal two-layer cross-attention decoder:
1693+
1694+
.. code-block:: text
1695+
1696+
query ──→ layer0(query, memory) ──→ layer1(x, memory) ──→ output
1697+
↑ ↑
1698+
memory ───────────┴─────────────────────────┘
1699+
(shared K/V input across layers)
1700+
"""
1701+
import math
1702+
1703+
class DecoderLayer(torch.nn.Module):
1704+
def __init__(self, d: int = 256):
1705+
super().__init__()
1706+
self.q_proj = torch.nn.Linear(d, d, bias=False)
1707+
self.k_proj = torch.nn.Linear(d, d, bias=False)
1708+
self.v_proj = torch.nn.Linear(d, d, bias=False)
1709+
self.out_proj = torch.nn.Linear(d, d, bias=False)
1710+
self.ffn1 = torch.nn.Linear(d, d * 2, bias=False)
1711+
self.ffn2 = torch.nn.Linear(d * 2, d, bias=False)
1712+
self.norm1 = torch.nn.LayerNorm(d)
1713+
self.norm2 = torch.nn.LayerNorm(d)
1714+
1715+
def forward(
1716+
self, x: torch.Tensor, mem: torch.Tensor
1717+
) -> torch.Tensor:
1718+
q = self.q_proj(x)
1719+
k = self.k_proj(mem)
1720+
v = self.v_proj(mem)
1721+
attn = torch.bmm(
1722+
q, k.transpose(-2, -1)
1723+
) / math.sqrt(q.size(-1))
1724+
attn = torch.softmax(attn, dim=-1)
1725+
out = self.out_proj(torch.bmm(attn, v))
1726+
x = self.norm1(x + out)
1727+
x = self.norm2(
1728+
x + self.ffn2(torch.relu(self.ffn1(x)))
1729+
)
1730+
return x
1731+
1732+
class TwoLayerDecoder(torch.nn.Module):
1733+
def __init__(self):
1734+
super().__init__()
1735+
self.layer0 = DecoderLayer()
1736+
self.layer1 = DecoderLayer()
1737+
1738+
def forward(
1739+
self, query: torch.Tensor, memory: torch.Tensor
1740+
) -> torch.Tensor:
1741+
x = self.layer0(query, memory)
1742+
x = self.layer1(x, memory)
1743+
return x
1744+
1745+
from torch.ao.quantization.quantize_pt2e import (
1746+
convert_pt2e,
1747+
prepare_pt2e,
1748+
)
1749+
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
1750+
XNNPACKQuantizer,
1751+
get_symmetric_quantization_config,
1752+
)
1753+
1754+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
1755+
XnnpackDynamicallyQuantizedPartitioner,
1756+
)
1757+
from executorch.exir import to_edge_transform_and_lower
1758+
1759+
model = TwoLayerDecoder().eval()
1760+
query = torch.randn(1, 10, 256)
1761+
memory = torch.randn(1, 20, 256)
1762+
1763+
exported = torch.export.export(
1764+
model, (query, memory), strict=False
1765+
)
1766+
1767+
quantizer = XNNPACKQuantizer().set_global(
1768+
get_symmetric_quantization_config(
1769+
is_per_channel=True, is_dynamic=True
1770+
)
1771+
)
1772+
prepared = prepare_pt2e(exported.module(), quantizer)
1773+
with torch.no_grad():
1774+
prepared(query, memory)
1775+
converted = convert_pt2e(prepared)
1776+
1777+
re_exported = torch.export.export(
1778+
converted, (query, memory), strict=False
1779+
)
1780+
1781+
# Before the fix this raised:
1782+
# AssertionError: Invalid partition, found dependency cycles
1783+
to_edge_transform_and_lower(
1784+
re_exported,
1785+
partitioner=[XnnpackDynamicallyQuantizedPartitioner()],
1786+
)

0 commit comments

Comments (0)