Commit 3998693
authored
Fix dependency cycle in GroupBasedPartitioner._can_merge_partitions (#18397)
## Summary
`GroupBasedPartitioner._can_merge_partitions()` only checks downstream
dependencies from `p2`, assuming `p2` is always topologically before
`p1`. This assumption fails when partition groups contain nodes spanning
wide topological ranges, causing false-negative cycle detection and
ultimately `AssertionError: Invalid partition, found dependency cycles`
at `fuse_as_graphmodule` time.
**Root cause:** Dynamic quantization inserts `choose_qparams` nodes that
are shared across multiple GEMM ops consuming the same activation. The
DSJ (Disjoint Set Join) phase merges these ops into groups whose nodes
*interleave* in topological order. When `_merge_partitions` later tries
to combine two such interleaved groups, the single-direction check (p2
only) misses the cycle path from p1 → external → p2.
**Fix:**
1. Check external users from **both** `p1` and `p2` (`combined_nodes`)
instead of only `p2`.
2. Add a `validate_partition()` safety net (BFS on live graph edges) to
catch any cycle the pre-computed `_DependencyViewer` might miss.
## Reproduction
The issue is triggered when lowering a cross-attention transformer
decoder with `XnnpackDynamicallyQuantizedPartitioner`. Multiple decoder
layers share the same encoder output for K/V projections, causing
`choose_qparams` sharing → DSJ group interleaving → false merge →
dependency cycle.
Minimal reproduction (no external dependencies beyond PyTorch +
ExecuTorch):
```python
import math, torch, torch.nn as nn
class DecoderLayer(nn.Module):
def __init__(self, d=256):
super().__init__()
self.q_proj = nn.Linear(d, d, bias=False)
self.k_proj = nn.Linear(d, d, bias=False)
self.v_proj = nn.Linear(d, d, bias=False)
self.out_proj = nn.Linear(d, d, bias=False)
self.ffn1 = nn.Linear(d, d * 2, bias=False)
self.ffn2 = nn.Linear(d * 2, d, bias=False)
self.norm1 = nn.LayerNorm(d)
self.norm2 = nn.LayerNorm(d)
def forward(self, x, mem):
q, k, v = self.q_proj(x), self.k_proj(mem), self.v_proj(mem)
attn = torch.softmax(torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)), dim=-1)
x = self.norm1(x + self.out_proj(torch.bmm(attn, v)))
return self.norm2(x + self.ffn2(torch.relu(self.ffn1(x))))
class TwoLayerDecoder(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = DecoderLayer()
self.layer1 = DecoderLayer()
def forward(self, query, memory):
return self.layer1(self.layer0(query, memory), memory)
# Export → dynamic quant → lower
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackDynamicallyQuantizedPartitioner
from executorch.exir import to_edge_transform_and_lower
model = TwoLayerDecoder().eval()
q, m = torch.randn(1, 10, 256), torch.randn(1, 20, 256)
exported = torch.export.export(model, (q, m), strict=False)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))
prepared = prepare_pt2e(exported.module(), quantizer)
with torch.no_grad(): prepared(q, m)
converted = convert_pt2e(prepared)
re_exported = torch.export.export(converted, (q, m), strict=False)
# Before fix: AssertionError: Invalid partition, found dependency cycles
to_edge_transform_and_lower(re_exported, partitioner=[XnnpackDynamicallyQuantizedPartitioner()])
```
## Test plan
- [x] Added `test_interleaved_groups_no_false_merge` in
`exir/backend/test/test_group_partitioner.py`
- [x] Verified the test fails without the fix and passes with the fix
- [ ] Existing `test_group_partitioner.py` tests pass
cc @JacobSzwejbka @angelayi @GregoryComer @digantdesai @cbilgin1 parent ed1c88b commit 3998693
3 files changed
Lines changed: 120 additions & 9 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
97 | 97 | | |
98 | 98 | | |
99 | 99 | | |
| 100 | + | |
| 101 | + | |
100 | 102 | | |
101 | 103 | | |
102 | 104 | | |
103 | 105 | | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
104 | 116 | | |
105 | | - | |
106 | | - | |
107 | | - | |
108 | | - | |
109 | | - | |
110 | | - | |
111 | | - | |
112 | | - | |
113 | | - | |
| 117 | + | |
114 | 118 | | |
115 | 119 | | |
116 | 120 | | |
| |||
121 | 125 | | |
122 | 126 | | |
123 | 127 | | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
124 | 135 | | |
125 | 136 | | |
126 | 137 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
426 | 426 | | |
427 | 427 | | |
428 | 428 | | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
429 | 433 | | |
430 | 434 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1672 | 1672 | | |
1673 | 1673 | | |
1674 | 1674 | | |
| 1675 | + | |
| 1676 | + | |
| 1677 | + | |
| 1678 | + | |
| 1679 | + | |
| 1680 | + | |
| 1681 | + | |
| 1682 | + | |
| 1683 | + | |
| 1684 | + | |
| 1685 | + | |
| 1686 | + | |
| 1687 | + | |
| 1688 | + | |
| 1689 | + | |
| 1690 | + | |
| 1691 | + | |
| 1692 | + | |
| 1693 | + | |
| 1694 | + | |
| 1695 | + | |
| 1696 | + | |
| 1697 | + | |
| 1698 | + | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
| 1710 | + | |
| 1711 | + | |
| 1712 | + | |
| 1713 | + | |
| 1714 | + | |
| 1715 | + | |
| 1716 | + | |
| 1717 | + | |
| 1718 | + | |
| 1719 | + | |
| 1720 | + | |
| 1721 | + | |
| 1722 | + | |
| 1723 | + | |
| 1724 | + | |
| 1725 | + | |
| 1726 | + | |
| 1727 | + | |
| 1728 | + | |
| 1729 | + | |
| 1730 | + | |
| 1731 | + | |
| 1732 | + | |
| 1733 | + | |
| 1734 | + | |
| 1735 | + | |
| 1736 | + | |
| 1737 | + | |
| 1738 | + | |
| 1739 | + | |
| 1740 | + | |
| 1741 | + | |
| 1742 | + | |
| 1743 | + | |
| 1744 | + | |
| 1745 | + | |
| 1746 | + | |
| 1747 | + | |
| 1748 | + | |
| 1749 | + | |
| 1750 | + | |
| 1751 | + | |
| 1752 | + | |
| 1753 | + | |
| 1754 | + | |
| 1755 | + | |
| 1756 | + | |
| 1757 | + | |
| 1758 | + | |
| 1759 | + | |
| 1760 | + | |
| 1761 | + | |
| 1762 | + | |
| 1763 | + | |
| 1764 | + | |
| 1765 | + | |
| 1766 | + | |
| 1767 | + | |
| 1768 | + | |
| 1769 | + | |
| 1770 | + | |
0 commit comments