Skip to content

Commit 3998693

Browse files
Fix dependency cycle in GroupBasedPartitioner._can_merge_partitions (#18397)
## Summary `GroupBasedPartitioner._can_merge_partitions()` only checks downstream dependencies from `p2`, assuming `p2` is always topologically before `p1`. This assumption fails when partition groups contain nodes spanning wide topological ranges, causing false-negative cycle detection and ultimately `AssertionError: Invalid partition, found dependency cycles` at `fuse_as_graphmodule` time. **Root cause:** Dynamic quantization inserts `choose_qparams` nodes that are shared across multiple GEMM ops consuming the same activation. The DSJ (Disjoint Set Join) phase merges these ops into groups whose nodes *interleave* in topological order. When `_merge_partitions` later tries to combine two such interleaved groups, the single-direction check (p2 only) misses the cycle path from p1 → external → p2. **Fix:** 1. Check external users from **both** `p1` and `p2` (`combined_nodes`) instead of only `p2`. 2. Add a `validate_partition()` safety net (BFS on live graph edges) to catch any cycle the pre-computed `_DependencyViewer` might miss. ## Reproduction The issue is triggered when lowering a cross-attention transformer decoder with `XnnpackDynamicallyQuantizedPartitioner`. Multiple decoder layers share the same encoder output for K/V projections, causing `choose_qparams` sharing → DSJ group interleaving → false merge → dependency cycle. 
Minimal reproduction (no external dependencies beyond PyTorch + ExecuTorch): ```python import math, torch, torch.nn as nn class DecoderLayer(nn.Module): def __init__(self, d=256): super().__init__() self.q_proj = nn.Linear(d, d, bias=False) self.k_proj = nn.Linear(d, d, bias=False) self.v_proj = nn.Linear(d, d, bias=False) self.out_proj = nn.Linear(d, d, bias=False) self.ffn1 = nn.Linear(d, d * 2, bias=False) self.ffn2 = nn.Linear(d * 2, d, bias=False) self.norm1 = nn.LayerNorm(d) self.norm2 = nn.LayerNorm(d) def forward(self, x, mem): q, k, v = self.q_proj(x), self.k_proj(mem), self.v_proj(mem) attn = torch.softmax(torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)), dim=-1) x = self.norm1(x + self.out_proj(torch.bmm(attn, v))) return self.norm2(x + self.ffn2(torch.relu(self.ffn1(x)))) class TwoLayerDecoder(nn.Module): def __init__(self): super().__init__() self.layer0 = DecoderLayer() self.layer1 = DecoderLayer() def forward(self, query, memory): return self.layer1(self.layer0(query, memory), memory) # Export → dynamic quant → lower from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackDynamicallyQuantizedPartitioner from executorch.exir import to_edge_transform_and_lower model = TwoLayerDecoder().eval() q, m = torch.randn(1, 10, 256), torch.randn(1, 20, 256) exported = torch.export.export(model, (q, m), strict=False) quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)) prepared = prepare_pt2e(exported.module(), quantizer) with torch.no_grad(): prepared(q, m) converted = convert_pt2e(prepared) re_exported = torch.export.export(converted, (q, m), strict=False) # Before fix: AssertionError: Invalid partition, found dependency cycles to_edge_transform_and_lower(re_exported, 
partitioner=[XnnpackDynamicallyQuantizedPartitioner()]) ``` ## Test plan - [x] Added `test_interleaved_groups_no_false_merge` in `exir/backend/test/test_group_partitioner.py` - [x] Verified the test fails without the fix and passes with the fix - [ ] Existing `test_group_partitioner.py` tests pass cc @JacobSzwejbka @angelayi @GregoryComer @digantdesai @cbilgin
1 parent ed1c88b commit 3998693

3 files changed

Lines changed: 120 additions & 9 deletions

File tree

exir/backend/canonical_partitioners/group_partitioner.py

Lines changed: 20 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -97,20 +97,24 @@ def __init__(
9797

9898
def _can_merge_partitions(self, p1, p2, partitions_by_id):
9999
"""Check if merging two partitions would create a cycle."""
100+
from torch.fx.passes.utils.fuser_utils import validate_partition
101+
100102
p1_nodes = set(partitions_by_id[p1].nodes.keys())
101103
p2_nodes = set(partitions_by_id[p2].nodes.keys())
102104
combined_nodes = p1_nodes.union(p2_nodes)
103105

106+
# Check external users from BOTH partitions. The original code only
107+
# checked p2 under the assumption that p2 is always topologically
108+
# before p1. However, when partition groups contain nodes that span
109+
# wide topological ranges (e.g. due to shared dynamic-quantization
110+
# choose_qparams nodes), the two partitions can *interleave* in
111+
# topological order, making the single-direction check insufficient.
112+
#
113+
# We still only need to collect the *direct* external users (not
114+
# transitive ones), because dependency_viewer.downstreams_of already
115+
# returns the full transitive closure.
104116
user_nodes = []
105-
# topologically, p2_nodes comes before p1_nodes, so we only
106-
# need to check the downstream nodes of p2.
107-
# Additionally, we don't need to check all the downstream nodes
108-
# of p2, we only need to check the nodes directly outside of p2.
109-
# example:
110-
# partition[a --> b --> c] --> d --> e --> f
111-
# we don't need to check [d, e, f] we only need to check [d] because
112-
# the downstream users of [d] will include [e, f]
113-
for node in p2_nodes:
117+
for node in combined_nodes:
114118
for user in node.users:
115119
if user not in combined_nodes:
116120
user_nodes.append(user)
@@ -121,6 +125,13 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
121125
if any(n in combined_nodes for n in downstream_nodes):
122126
return False
123127

128+
# Final safety net: validate_partition performs a direct BFS on the
129+
# live graph edges, catching any cycle the pre-computed
130+
# dependency_viewer might miss (e.g. when the graph was transformed
131+
# after the viewer was built).
132+
if not validate_partition(list(combined_nodes)):
133+
return False
134+
124135
return True
125136

126137
def _process_all_nodes(

exir/backend/test/BUCK

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -426,5 +426,9 @@ fbcode_target(_kind = runtime.python_test,
426426
deps = [
427427
"//caffe2:torch",
428428
"//executorch/exir/backend/canonical_partitioners:group_partitioner_lib",
429+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
430+
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
431+
"//executorch/exir:lib",
432+
"//pytorch/ao:torchao",
429433
],
430434
)

exir/backend/test/test_group_partitioner.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,3 +1672,99 @@ def forward(self, x):
16721672

16731673
# With allows_single_node_partition=True, we should have partitions
16741674
self.assertGreater(len(partitions_with_single), 0)
1675+
1676+
def test_interleaved_groups_no_false_merge(self):
    """
    Test that _can_merge_partitions correctly rejects merges when two
    partition groups interleave in topological order.

    This reproduces a real-world failure with
    XnnpackDynamicallyQuantizedPartitioner on transformer decoder models
    where cross-attention K/V projections across multiple decoder layers
    share the same encoder ``memory`` input. Dynamic quantization inserts
    a shared ``choose_qparams`` node for that input, causing the DSJ phase
    to create partition groups whose nodes span wide topological ranges.
    When GroupBasedPartitioner later tries to merge these groups, the
    original single-direction downstream check missed the cycle because it
    assumed p2 is entirely before p1 — which is false for interleaved
    groups.

    The model is a minimal two-layer cross-attention decoder:

    .. code-block:: text

        query ──→ layer0(query, memory) ──→ layer1(x, memory) ──→ output
                          ↑                         ↑
        memory ───────────┴─────────────────────────┘
                 (shared K/V input across layers)
    """
    import math

    class DecoderLayer(torch.nn.Module):
        def __init__(self, d: int = 256):
            super().__init__()
            self.q_proj = torch.nn.Linear(d, d, bias=False)
            self.k_proj = torch.nn.Linear(d, d, bias=False)
            self.v_proj = torch.nn.Linear(d, d, bias=False)
            self.out_proj = torch.nn.Linear(d, d, bias=False)
            self.ffn1 = torch.nn.Linear(d, d * 2, bias=False)
            self.ffn2 = torch.nn.Linear(d * 2, d, bias=False)
            self.norm1 = torch.nn.LayerNorm(d)
            self.norm2 = torch.nn.LayerNorm(d)

        def forward(self, x: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
            # Q comes from the decoder stream, K/V from the shared memory —
            # the K/V sharing across layers is what triggers the bug.
            q = self.q_proj(x)
            k = self.k_proj(mem)
            v = self.v_proj(mem)
            attn = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
            attn = torch.softmax(attn, dim=-1)
            out = self.out_proj(torch.bmm(attn, v))
            x = self.norm1(x + out)
            x = self.norm2(x + self.ffn2(torch.relu(self.ffn1(x))))
            return x

    class TwoLayerDecoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer0 = DecoderLayer()
            self.layer1 = DecoderLayer()

        def forward(
            self, query: torch.Tensor, memory: torch.Tensor
        ) -> torch.Tensor:
            x = self.layer0(query, memory)
            x = self.layer1(x, memory)
            return x

    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
    )
    from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )
    from executorch.exir import to_edge_transform_and_lower
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = TwoLayerDecoder().eval()
    query = torch.randn(1, 10, 256)
    memory = torch.randn(1, 20, 256)

    exported = torch.export.export(model, (query, memory), strict=False)

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
    )
    prepared = prepare_pt2e(exported.module(), quantizer)
    with torch.no_grad():
        prepared(query, memory)
    converted = convert_pt2e(prepared)

    re_exported = torch.export.export(converted, (query, memory), strict=False)

    # Before the fix this raised:
    #   AssertionError: Invalid partition, found dependency cycles
    to_edge_transform_and_lower(
        re_exported,
        partitioner=[XnnpackDynamicallyQuantizedPartitioner()],
    )

0 commit comments

Comments (0)