Skip to content

Commit 7699a00

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Handle rank-changing views in FuseCascadedTransposeOrPermuteOps (#19539)
Summary: Extend FuseCascadedTransposeOrPermuteOps to fuse permute→view_copy→permute patterns where a squeeze/unsqueeze view sits between two permutes. The pass computes the combined effect of permute+view+permute: if the permutation components cancel out (identity), the entire chain is replaced with a single view_copy. This handles patterns like permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) which composes to a simple view_copy (the permutations cancel, leaving only the reshape). Differential Revision: D104775245
1 parent bb8197e commit 7699a00

2 files changed

Lines changed: 210 additions & 78 deletions

File tree

backends/transforms/fuse_cascaded_transpose_or_permute_ops.py

Lines changed: 114 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,40 @@
2020
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
2121
"""
2222
Fuse a chain of transpose and permute ops into a single permute or a no-op.
23-
Handles branches and chains permutes.
23+
Handles branches and chains of permutes, including permute-view-permute
24+
patterns where a squeeze/unsqueeze view sits between two permutes.
2425
"""
2526

2627
transpose_or_permute_target = {
2728
exir_ops.edge.aten.transpose_copy.int,
2829
exir_ops.edge.aten.permute_copy.default,
2930
}
3031

32+
_VIEW_OPS = {
33+
exir_ops.edge.aten.view_copy.default,
34+
exir_ops.edge.aten.view.default,
35+
}
36+
3137
@property
3238
def targets(self) -> list[EdgeOpOverload]:
3339
return list(self.transpose_or_permute_target)
3440

3541
def maybe_remove_or_replace(self, node: Node) -> bool:
36-
# Fuse with the parent node if it's also a permute or a transpose. Since the
37-
# pass interface traverses all ops in order the pass will properly fuse a chain
38-
# of permutes.
3942
parent_node = get_arg(node, "input", Node)
40-
if parent_node.target not in self.transpose_or_permute_target:
41-
return False
42-
input_of_parent = get_arg(parent_node, "input", Node)
4343

44-
# Compute combined effect of permutes.
44+
# Case 1: Direct permute/transpose → permute/transpose
45+
if parent_node.target in self.transpose_or_permute_target:
46+
return self._fuse_direct(node, parent_node)
47+
48+
# Case 2: permute → view_copy(squeeze/unsqueeze) → permute
49+
if parent_node.target in self._VIEW_OPS:
50+
return self._fuse_across_view(node, parent_node)
51+
52+
return False
53+
54+
def _fuse_direct(self, node: Node, parent_node: Node) -> bool:
55+
"""Fuse two adjacent permute/transpose ops."""
56+
input_of_parent = get_arg(parent_node, "input", Node)
4557
dims = list(range(node.meta["val"].ndim))
4658

4759
if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466
else:
5567
dims = get_permuted_dims(node, dims)
5668

57-
# If combined effect is identity replace the node with input.
5869
if dims == sorted(dims):
5970
node.replace_all_uses_with(input_of_parent)
6071
else:
@@ -67,3 +78,97 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778
node.replace_all_uses_with(new_permute)
6879

6980
return True
81+
82+
def _fuse_across_view(self, node: Node, view_node: Node) -> bool:
83+
"""Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy."""
84+
# view_node must have exactly one user (this permute node)
85+
if len(list(view_node.users.keys())) != 1:
86+
return False
87+
# view_node's parent must be a permute/transpose
88+
view_input = get_arg(view_node, "input", Node)
89+
if view_input.target not in self.transpose_or_permute_target:
90+
return False
91+
# The view must be a squeeze or unsqueeze (rank differs by 1)
92+
if view_node.meta.get("val") is None or view_input.meta.get("val") is None:
93+
return False
94+
view_in_shape = list(view_input.meta["val"].shape)
95+
view_out_shape = list(view_node.meta["val"].shape)
96+
if abs(len(view_in_shape) - len(view_out_shape)) != 1:
97+
return False
98+
99+
# Get the input before the first permute
100+
input_of_first_permute = get_arg(view_input, "input", Node)
101+
if input_of_first_permute.meta.get("val") is None:
102+
return False
103+
104+
# Compute the combined effect on the original input dimensions
105+
# Start with identity dims for the original input
106+
original_ndim = input_of_first_permute.meta["val"].ndim
107+
dims = list(range(original_ndim))
108+
109+
# Apply first permute
110+
if view_input.target == exir_ops.edge.aten.transpose_copy.int:
111+
dims = get_transposed_dims(view_input, dims)
112+
else:
113+
dims = get_permuted_dims(view_input, dims)
114+
115+
# Apply the view (squeeze/unsqueeze)
116+
if len(view_out_shape) == len(view_in_shape) + 1:
117+
# unsqueeze: insert a new dim
118+
index = self._find_extra_one(view_out_shape, view_in_shape)
119+
if index == -1:
120+
return False
121+
dims = [x + 1 if x >= index else x for x in dims]
122+
dims.insert(index, -1) # -1 marks the inserted dim
123+
elif len(view_in_shape) == len(view_out_shape) + 1:
124+
# squeeze: remove a dim
125+
index = self._find_extra_one(view_in_shape, view_out_shape)
126+
if index == -1:
127+
return False
128+
if dims[index] != -1:
129+
# Safe: permutation preserves dimension sizes, so a size-1
130+
# intermediate dim necessarily originated from a size-1 input dim.
131+
pass
132+
del dims[index]
133+
134+
# Apply second permute (node)
135+
if node.target == exir_ops.edge.aten.transpose_copy.int:
136+
node_dims = list(range(len(dims)))
137+
node_dims = get_transposed_dims(node, node_dims)
138+
dims = [dims[d] for d in node_dims]
139+
else:
140+
perm = list(node.args[1])
141+
dims = [dims[d] for d in perm]
142+
143+
# Check if the combined effect (ignoring -1 inserted dims) is identity
144+
real_dims = [d for d in dims if d != -1]
145+
output_shape = list(node.meta["val"].shape)
146+
147+
if real_dims == sorted(real_dims):
148+
# Combined permutations are identity — replace with view_copy
149+
# (the only remaining effect is the squeeze/unsqueeze reshape)
150+
if output_shape == list(input_of_first_permute.meta["val"].shape):
151+
# Total no-op: replace with input
152+
node.replace_all_uses_with(input_of_first_permute)
153+
else:
154+
with node.graph.inserting_before(node):
155+
new_view = node.graph.call_function(
156+
exir_ops.edge.aten.view_copy.default,
157+
args=(input_of_first_permute, output_shape),
158+
)
159+
new_view.meta = node.meta
160+
node.replace_all_uses_with(new_view)
161+
return True
162+
163+
return False
164+
165+
@staticmethod
166+
def _find_extra_one(longer, shorter):
167+
if len(longer) != len(shorter) + 1:
168+
return -1
169+
for i in range(len(shorter)):
170+
if longer[i] != shorter[i]:
171+
if longer[i] == 1 and shorter[i:] == longer[i + 1:]:
172+
return i
173+
return -1
174+
return len(shorter) if longer[-1] == 1 else -1

backends/transforms/test/test_permute_optimization_passes.py

Lines changed: 96 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
2020
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
2121
)
22-
from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import (
23-
ReplaceNopTransposeOrPermuteWithViewPass,
24-
)
2522
from executorch.backends.transforms.remove_permutes_around_elementwise_ops import (
2623
RemovePermutesAroundElementwiseOps,
2724
)
25+
from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import (
26+
ReplaceNopTransposeOrPermuteWithViewPass,
27+
)
2828
from executorch.exir.dialects._ops import ops as exir_ops
2929
from executorch.exir.pass_base import PassResult
3030
from torch.utils import _pytree as pytree
@@ -122,15 +122,12 @@ def test_cascaded_permutes_multiple_users(self) -> None:
122122
permute1 = builder.call_operator(
123123
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
124124
)
125-
# permute2 reverses permute1 => identity
126125
permute2 = builder.call_operator(
127126
op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 3, 1, 2])
128127
)
129-
# permute3: different permutation
130128
permute3 = builder.call_operator(
131129
op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 2, 1, 3])
132130
)
133-
# permute4 -> permute5: chained
134131
permute4 = builder.call_operator(
135132
op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [3, 2, 0, 1])
136133
)
@@ -151,6 +148,38 @@ def test_cascaded_permutes_multiple_users(self) -> None:
151148
"FuseCascadedTransposeOrPermuteOps",
152149
)
153150

151+
def test_permute_view_permute_fuse(self) -> None:
152+
"""permute_3D([0,2,1]) → view(unsqueeze) → permute_4D([0,2,3,1]) should
153+
be replaced with a single view_copy (permutations cancel out)."""
154+
builder = GraphBuilder()
155+
x_data = torch.randn(1, 40, 18)
156+
x = builder.placeholder("x", x_data)
157+
p1 = builder.call_operator(
158+
op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 1])
159+
)
160+
v = builder.call_operator(
161+
op=exir_ops.edge.aten.view_copy.default, args=(p1, [1, 18, 1, 40])
162+
)
163+
p2 = builder.call_operator(
164+
op=exir_ops.edge.aten.permute_copy.default, args=(v, [0, 2, 3, 1])
165+
)
166+
builder.output([p2])
167+
original = builder.get_graph_module()
168+
gm_before = copy.deepcopy(original)
169+
170+
p = FuseCascadedTransposeOrPermuteOps()
171+
result = cast(PassResult, p(original))
172+
self.assertTrue(result.modified)
173+
gm = result.graph_module
174+
175+
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
176+
self.assertGreaterEqual(
177+
count_node(gm, exir_ops.edge.aten.view_copy.default), 1
178+
)
179+
validate_numerics(
180+
gm_before, gm, [x_data], "FuseCascadedAcrossView",
181+
)
182+
154183

155184
# ──────────────────────────────────────────────────────────────────────
156185
# Tests for FuseCascadedViewOps
@@ -250,7 +279,6 @@ def test_permute3_view4_chains(self) -> None:
250279

251280
self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2)
252281
self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2)
253-
# Verify order: views before permutes
254282
targets = get_compute_nodes(gm)
255283
view_indices = [
256284
i
@@ -350,7 +378,6 @@ def test_negative_not_squeeze_like(self) -> None:
350378
count_node(result.graph_module, exir_ops.edge.aten.permute_copy.default),
351379
2,
352380
)
353-
# Order unchanged: view, permute, view, permute
354381
targets = get_compute_nodes(result.graph_module)
355382
self.assertEqual(targets[0], exir_ops.edge.aten.view_copy.default)
356383
self.assertEqual(targets[1], exir_ops.edge.aten.permute_copy.default)
@@ -383,6 +410,67 @@ def test_replace_nop_transpose_with_view_float(self) -> None:
383410
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
384411
)
385412

413+
def test_replace_nop_transpose_with_view_int(self) -> None:
414+
x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64)
415+
gm = single_op_builder(
416+
placeholders=(x,),
417+
op=exir_ops.edge.aten.transpose_copy.int,
418+
args=(x, 1, 0),
419+
)
420+
gm_before = copy.deepcopy(gm)
421+
422+
p = ReplaceNopTransposeOrPermuteWithViewPass()
423+
result = cast(PassResult, p(gm))
424+
self.assertTrue(result.modified)
425+
gm_after = result.graph_module
426+
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0)
427+
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
428+
validate_numerics(
429+
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
430+
)
431+
432+
def test_replace_nop_permute_5d(self) -> None:
433+
x = torch.randn(3, 1, 3, 1, 4)
434+
gm = single_op_builder(
435+
placeholders=(x,),
436+
op=exir_ops.edge.aten.permute_copy.default,
437+
args=(x, [0, 2, 4, 1, 3]),
438+
)
439+
gm_before = copy.deepcopy(gm)
440+
441+
p = ReplaceNopTransposeOrPermuteWithViewPass()
442+
result = cast(PassResult, p(gm))
443+
self.assertTrue(result.modified)
444+
gm_after = result.graph_module
445+
self.assertEqual(
446+
count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0
447+
)
448+
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
449+
validate_numerics(
450+
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
451+
)
452+
453+
def test_replace_nop_permute_3d(self) -> None:
454+
x = torch.randn(1, 3, 4)
455+
gm = single_op_builder(
456+
placeholders=(x,),
457+
op=exir_ops.edge.aten.permute_copy.default,
458+
args=(x, [1, 2, 0]),
459+
)
460+
gm_before = copy.deepcopy(gm)
461+
462+
p = ReplaceNopTransposeOrPermuteWithViewPass()
463+
result = cast(PassResult, p(gm))
464+
self.assertTrue(result.modified)
465+
gm_after = result.graph_module
466+
self.assertEqual(
467+
count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0
468+
)
469+
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
470+
validate_numerics(
471+
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
472+
)
473+
386474

387475
# ──────────────────────────────────────────────────────────────────────
388476
# Tests for RemovePermutesAroundElementwiseOps cross-view handling
@@ -458,64 +546,3 @@ def test_4d_permute_squeeze_clamp_3d_permute(self) -> None:
458546
gm_before, result.graph_module, [x_data],
459547
"4D_permute_squeeze_clamp_3D_permute",
460548
)
461-
462-
def test_replace_nop_transpose_with_view_int(self) -> None:
463-
x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64)
464-
gm = single_op_builder(
465-
placeholders=(x,),
466-
op=exir_ops.edge.aten.transpose_copy.int,
467-
args=(x, 1, 0),
468-
)
469-
gm_before = copy.deepcopy(gm)
470-
471-
p = ReplaceNopTransposeOrPermuteWithViewPass()
472-
result = cast(PassResult, p(gm))
473-
self.assertTrue(result.modified)
474-
gm_after = result.graph_module
475-
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0)
476-
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
477-
validate_numerics(
478-
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
479-
)
480-
481-
def test_replace_nop_permute_5d(self) -> None:
482-
x = torch.randn(3, 1, 3, 1, 4)
483-
gm = single_op_builder(
484-
placeholders=(x,),
485-
op=exir_ops.edge.aten.permute_copy.default,
486-
args=(x, [0, 2, 4, 1, 3]),
487-
)
488-
gm_before = copy.deepcopy(gm)
489-
490-
p = ReplaceNopTransposeOrPermuteWithViewPass()
491-
result = cast(PassResult, p(gm))
492-
self.assertTrue(result.modified)
493-
gm_after = result.graph_module
494-
self.assertEqual(
495-
count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0
496-
)
497-
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
498-
validate_numerics(
499-
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
500-
)
501-
502-
def test_replace_nop_permute_3d(self) -> None:
503-
x = torch.randn(1, 3, 4)
504-
gm = single_op_builder(
505-
placeholders=(x,),
506-
op=exir_ops.edge.aten.permute_copy.default,
507-
args=(x, [1, 2, 0]),
508-
)
509-
gm_before = copy.deepcopy(gm)
510-
511-
p = ReplaceNopTransposeOrPermuteWithViewPass()
512-
result = cast(PassResult, p(gm))
513-
self.assertTrue(result.modified)
514-
gm_after = result.graph_module
515-
self.assertEqual(
516-
count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0
517-
)
518-
self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1)
519-
validate_numerics(
520-
gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass"
521-
)

0 commit comments

Comments
 (0)