2020class FuseCascadedTransposeOrPermuteOps (RemoveOrReplacePassInterface ):
2121 """
2222 Fuse a chain of transpose and permute ops into a single permute or a no-op.
23- Handles branches and chains permutes.
23+ Handles branches and chains of permutes, including permute-view-permute
24+ patterns where a squeeze/unsqueeze view sits between two permutes.
2425 """
2526
2627 transpose_or_permute_target = {
2728 exir_ops .edge .aten .transpose_copy .int ,
2829 exir_ops .edge .aten .permute_copy .default ,
2930 }
3031
32+ _VIEW_OPS = {
33+ exir_ops .edge .aten .view_copy .default ,
34+ exir_ops .edge .aten .view .default ,
35+ }
36+
3137 @property
3238 def targets (self ) -> list [EdgeOpOverload ]:
3339 return list (self .transpose_or_permute_target )
3440
3541 def maybe_remove_or_replace (self , node : Node ) -> bool :
36- # Fuse with the parent node if it's also a permute or a transpose. Since the
37- # pass interface traverses all ops in order the pass will properly fuse a chain
38- # of permutes.
3942 parent_node = get_arg (node , "input" , Node )
40- if parent_node .target not in self .transpose_or_permute_target :
41- return False
42- input_of_parent = get_arg (parent_node , "input" , Node )
4343
44- # Compute combined effect of permutes.
44+ # Case 1: Direct permute/transpose → permute/transpose
45+ if parent_node .target in self .transpose_or_permute_target :
46+ return self ._fuse_direct (node , parent_node )
47+
48+ # Case 2: permute → view_copy(squeeze/unsqueeze) → permute
49+ if parent_node .target in self ._VIEW_OPS :
50+ return self ._fuse_across_view (node , parent_node )
51+
52+ return False
53+
54+ def _fuse_direct (self , node : Node , parent_node : Node ) -> bool :
55+ """Fuse two adjacent permute/transpose ops."""
56+ input_of_parent = get_arg (parent_node , "input" , Node )
4557 dims = list (range (node .meta ["val" ].ndim ))
4658
4759 if parent_node .target == exir_ops .edge .aten .transpose_copy .int :
@@ -54,7 +66,6 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
5466 else :
5567 dims = get_permuted_dims (node , dims )
5668
57- # If combined effect is identity replace the node with input.
5869 if dims == sorted (dims ):
5970 node .replace_all_uses_with (input_of_parent )
6071 else :
@@ -67,3 +78,97 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
6778 node .replace_all_uses_with (new_permute )
6879
6980 return True
81+
82+ def _fuse_across_view (self , node : Node , view_node : Node ) -> bool :
83+ """Fuse permute -> view(squeeze/unsqueeze) -> permute into a view_copy."""
84+ # view_node must have exactly one user (this permute node)
85+ if len (view_node .users ) != 1 :
86+ return False
87+ # view_node's parent must be a permute/transpose
88+ view_input = get_arg (view_node , "input" , Node )
89+ if view_input .target not in self .transpose_or_permute_target :
90+ return False
91+ # The view must be a squeeze or unsqueeze (rank differs by 1)
92+ if view_node .meta .get ("val" ) is None or view_input .meta .get ("val" ) is None :
93+ return False
94+ view_in_shape = view_input .meta ["val" ].shape
95+ view_out_shape = view_node .meta ["val" ].shape
96+ if abs (len (view_in_shape ) - len (view_out_shape )) != 1 :
97+ return False
98+
99+ # Get the input before the first permute
100+ input_of_first_permute = get_arg (view_input , "input" , Node )
101+ if input_of_first_permute .meta .get ("val" ) is None :
102+ return False
103+
104+ # Compute the combined effect on the original input dimensions
105+ # Start with identity dims for the original input
106+ original_ndim = input_of_first_permute .meta ["val" ].ndim
107+ dims = list (range (original_ndim ))
108+
109+ # Apply first permute
110+ if view_input .target == exir_ops .edge .aten .transpose_copy .int :
111+ dims = get_transposed_dims (view_input , dims )
112+ else :
113+ dims = get_permuted_dims (view_input , dims )
114+
115+ # Apply the view (squeeze/unsqueeze)
116+ if len (view_out_shape ) == len (view_in_shape ) + 1 :
117+ # unsqueeze: insert a new dim
118+ index = self ._find_extra_one (view_out_shape , view_in_shape )
119+ if index == - 1 :
120+ return False
121+ dims = [x + 1 if x >= index else x for x in dims ]
122+ dims .insert (index , - 1 ) # -1 marks the inserted dim
123+ elif len (view_in_shape ) == len (view_out_shape ) + 1 :
124+ # squeeze: remove a dim
125+ index = self ._find_extra_one (view_in_shape , view_out_shape )
126+ if index == - 1 :
127+ return False
128+ if dims [index ] != - 1 :
129+ # Safe: permutation preserves dimension sizes, so a size-1
130+ # intermediate dim necessarily originated from a size-1 input dim.
131+ pass
132+ del dims [index ]
133+
134+ # Apply second permute (node)
135+ if node .target == exir_ops .edge .aten .transpose_copy .int :
136+ node_dims = list (range (len (dims )))
137+ node_dims = get_transposed_dims (node , node_dims )
138+ dims = [dims [d ] for d in node_dims ]
139+ else :
140+ perm = get_arg (node , "dims" )
141+ dims = [dims [d ] for d in perm ]
142+
143+ # Check if the combined effect (ignoring -1 inserted dims) is identity
144+ real_dims = [d for d in dims if d != - 1 ]
145+
146+ if real_dims == sorted (real_dims ):
147+ # Combined permutations are identity — replace with view_copy
148+ # (the only remaining effect is the squeeze/unsqueeze reshape)
149+ output_shape = node .meta ["val" ].shape
150+ if output_shape == input_of_first_permute .meta ["val" ].shape :
151+ # Total no-op: replace with input
152+ node .replace_all_uses_with (input_of_first_permute )
153+ else :
154+ with node .graph .inserting_before (node ):
155+ new_view = node .graph .call_function (
156+ exir_ops .edge .aten .view_copy .default ,
157+ args = (input_of_first_permute , list (output_shape )),
158+ )
159+ new_view .meta = node .meta
160+ node .replace_all_uses_with (new_view )
161+ return True
162+
163+ return False
164+
165+ @staticmethod
166+ def _find_extra_one (longer , shorter ):
167+ if len (longer ) != len (shorter ) + 1 :
168+ return - 1
169+ for i in range (len (shorter )):
170+ if longer [i ] != shorter [i ]:
171+ if longer [i ] == 1 and shorter [i :] == longer [i + 1 :]:
172+ return i
173+ return - 1
174+ return len (shorter ) if longer [- 1 ] == 1 else - 1
0 commit comments