Skip to content

Commit 88fea01

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Handle rank-changing views in FuseCascadedTransposeOrPermuteOps (#19539)
Summary: Extend FuseCascadedTransposeOrPermuteOps to fuse permute→view_copy→permute patterns where a squeeze/unsqueeze view sits between two permutes. The pass computes the combined effect of permute+view+permute: if the permutation components cancel out (identity), the entire chain is replaced with a single view_copy. This handles patterns like permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) which composes to a simple view_copy (the permutations cancel, leaving only the reshape). Differential Revision: D104775245
1 parent d833a22 commit 88fea01

2 files changed

Lines changed: 324 additions & 78 deletions

File tree

backends/transforms/fuse_cascaded_transpose_or_permute_ops.py

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,40 @@
2020
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
2121
"""
2222
Fuse a chain of transpose and permute ops into a single permute or a no-op.
23-
Handles branches and chains permutes.
23+
Handles branches and chains of permutes, including permute-view-permute
24+
patterns where a squeeze/unsqueeze view sits between two permutes.
2425
"""
2526

2627
transpose_or_permute_target = {
2728
exir_ops.edge.aten.transpose_copy.int,
2829
exir_ops.edge.aten.permute_copy.default,
2930
}
3031

32+
_VIEW_OPS = {
33+
exir_ops.edge.aten.view_copy.default,
34+
exir_ops.edge.aten.view.default,
35+
}
36+
3137
@property
3238
def targets(self) -> list[EdgeOpOverload]:
3339
return list(self.transpose_or_permute_target)
3440

3541
def maybe_remove_or_replace(self, node: Node) -> bool:
36-
# Fuse with the parent node if it's also a permute or a transpose. Since the
37-
# pass interface traverses all ops in order the pass will properly fuse a chain
38-
# of permutes.
3942
parent_node = get_arg(node, "input", Node)
40-
if parent_node.target not in self.transpose_or_permute_target:
41-
return False
42-
input_of_parent = get_arg(parent_node, "input", Node)
4343

44-
# Compute combined effect of permutes.
44+
# Case 1: Direct permute/transpose → permute/transpose
45+
if parent_node.target in self.transpose_or_permute_target:
46+
return self._fuse_direct(node, parent_node)
47+
48+
# Case 2: permute → view_copy(squeeze/unsqueeze) → permute
49+
if parent_node.target in self._VIEW_OPS:
50+
return self._fuse_across_view(node, parent_node)
51+
52+
return False
53+
54+
def _fuse_direct(self, node: Node, parent_node: Node) -> bool:
55+
"""Fuse two adjacent permute/transpose ops."""
56+
input_of_parent = get_arg(parent_node, "input", Node)
4557
dims = list(range(node.meta["val"].ndim))
4658

4759
if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466
else:
5567
dims = get_permuted_dims(node, dims)
5668

57-
# If combined effect is identity replace the node with input.
5869
if dims == sorted(dims):
5970
node.replace_all_uses_with(input_of_parent)
6071
else:
@@ -67,3 +78,97 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778
node.replace_all_uses_with(new_permute)
6879

6980
return True
81+
82+
def _fuse_across_view(self, node: Node, view_node: Node) -> bool:
83+
"""Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy."""
84+
# view_node must have exactly one user (this permute node)
85+
if len(view_node.users) != 1:
86+
return False
87+
# view_node's parent must be a permute/transpose
88+
view_input = get_arg(view_node, "input", Node)
89+
if view_input.target not in self.transpose_or_permute_target:
90+
return False
91+
# The view must be a squeeze or unsqueeze (rank differs by 1)
92+
if view_node.meta.get("val") is None or view_input.meta.get("val") is None:
93+
return False
94+
view_in_shape = view_input.meta["val"].shape
95+
view_out_shape = view_node.meta["val"].shape
96+
if abs(len(view_in_shape) - len(view_out_shape)) != 1:
97+
return False
98+
99+
# Get the input before the first permute
100+
input_of_first_permute = get_arg(view_input, "input", Node)
101+
if input_of_first_permute.meta.get("val") is None:
102+
return False
103+
104+
# Compute the combined effect on the original input dimensions
105+
# Start with identity dims for the original input
106+
original_ndim = input_of_first_permute.meta["val"].ndim
107+
dims = list(range(original_ndim))
108+
109+
# Apply first permute
110+
if view_input.target == exir_ops.edge.aten.transpose_copy.int:
111+
dims = get_transposed_dims(view_input, dims)
112+
else:
113+
dims = get_permuted_dims(view_input, dims)
114+
115+
# Apply the view (squeeze/unsqueeze)
116+
if len(view_out_shape) == len(view_in_shape) + 1:
117+
# unsqueeze: insert a new dim
118+
index = self._find_extra_one(view_out_shape, view_in_shape)
119+
if index == -1:
120+
return False
121+
dims = [x + 1 if x >= index else x for x in dims]
122+
dims.insert(index, -1) # -1 marks the inserted dim
123+
elif len(view_in_shape) == len(view_out_shape) + 1:
124+
# squeeze: remove a dim
125+
index = self._find_extra_one(view_in_shape, view_out_shape)
126+
if index == -1:
127+
return False
128+
if dims[index] != -1:
129+
# Safe: permutation preserves dimension sizes, so a size-1
130+
# intermediate dim necessarily originated from a size-1 input dim.
131+
pass
132+
del dims[index]
133+
134+
# Apply second permute (node)
135+
if node.target == exir_ops.edge.aten.transpose_copy.int:
136+
node_dims = list(range(len(dims)))
137+
node_dims = get_transposed_dims(node, node_dims)
138+
dims = [dims[d] for d in node_dims]
139+
else:
140+
perm = get_arg(node, "dims")
141+
dims = [dims[d] for d in perm]
142+
143+
# Check if the combined effect (ignoring -1 inserted dims) is identity
144+
real_dims = [d for d in dims if d != -1]
145+
146+
if real_dims == sorted(real_dims):
147+
# Combined permutations are identity — replace with view_copy
148+
# (the only remaining effect is the squeeze/unsqueeze reshape)
149+
output_shape = node.meta["val"].shape
150+
if output_shape == input_of_first_permute.meta["val"].shape:
151+
# Total no-op: replace with input
152+
node.replace_all_uses_with(input_of_first_permute)
153+
else:
154+
with node.graph.inserting_before(node):
155+
new_view = node.graph.call_function(
156+
exir_ops.edge.aten.view_copy.default,
157+
args=(input_of_first_permute, list(output_shape)),
158+
)
159+
new_view.meta = node.meta
160+
node.replace_all_uses_with(new_view)
161+
return True
162+
163+
return False
164+
165+
@staticmethod
166+
def _find_extra_one(longer, shorter):
167+
if len(longer) != len(shorter) + 1:
168+
return -1
169+
for i in range(len(shorter)):
170+
if longer[i] != shorter[i]:
171+
if longer[i] == 1 and shorter[i:] == longer[i + 1:]:
172+
return i
173+
return -1
174+
return len(shorter) if longer[-1] == 1 else -1

0 commit comments

Comments
 (0)