1919
2020import torch
2121import torch .fx
22- from executorch .backends .cadence .aot .compiler_utils import quantize_tensor_multiplier
22+ from executorch .backends .cadence .aot .compiler_utils import (
23+ quantize_tensor_multiplier ,
24+ transpose_dims_to_permute_order ,
25+ )
2326from executorch .backends .cadence .aot .fuse_ops import FuseCascadedTransposeOrPermuteOps
2427from executorch .backends .cadence .aot .pass_utils import (
2528 CadencePassAttribute ,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358
356359 # Handle transpose: if mat2 is a transpose op, extract the original tensor
357360 transposed_mat2 = False
358- if (
359- mat2 .op == "call_function"
360- and mat2 .target == exir_ops .edge .aten .transpose_copy . int
361+ if mat2 . op == "call_function" and (
362+ mat2 .target == exir_ops . edge . aten . transpose_copy . int
363+ or mat2 .target == exir_ops .edge .aten .permute_copy . default
361364 ):
362365 # mat2 is already transposed, so we use the input to the transpose
363366 mat2 = cast (torch .fx .Node , mat2 .args [0 ])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408 # Transpose mat2 if it wasn't already transposed
406409 if not transposed_mat2 :
407410 with graph .inserting_before (node ):
411+ ndim = len (mat2 .meta ["val" ].shape )
412+ perm = transpose_dims_to_permute_order (ndim , - 1 , - 2 )
408413 mat2 = graph .call_function (
409- exir_ops .edge .aten .transpose_copy . int ,
410- args = (mat2 , - 1 , - 2 ),
414+ exir_ops .edge .aten .permute_copy . default ,
415+ args = (mat2 , perm ),
411416 )
412417
413418 # Metadata copy important
@@ -430,47 +435,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435 return True
431436
432437
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
    """
    Replace permute op with transpose if the permutation is only along
    two dimensions.

    A permute whose order equals the identity order is dropped and its users
    are pointed at its input. A permute whose order differs from the identity
    in exactly two positions is rewritten as a transpose_copy over those two
    dimensions. Any other permute is left untouched.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass visits: only permute_copy.default."""
        return [exir_ops.edge.aten.permute_copy.default]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """
        Simplify a single permute_copy node.

        Returns True if the node's uses were redirected (to the input or to
        a new transpose_copy node), False if the node was left as-is.
        """
        in_tensor = node.args[0]
        assert isinstance(in_tensor, torch.fx.Node)
        rank = len(in_tensor.meta["val"].shape)

        # Negative entries index from the end. Normalize them first so that
        # e.g. [-3, -2, 2] on a rank-3 input is recognized as the identity
        # order rather than being mistaken for a two-dimension swap.
        new_dims = [
            d + rank if d < 0 else d for d in cast(Sequence[int], node.args[1])
        ]

        # Positions at which the requested order departs from the identity.
        diff = [od for od, nd in zip(range(rank), new_dims) if od != nd]

        # Identity order: the permute is a no-op, forward the input.
        if not diff:
            node.replace_all_uses_with(in_tensor)
            return True

        # Exactly two positions differ: the permute is a swap of those two
        # dimensions, which is what transpose_copy expresses.
        if len(diff) == 2:
            with node.graph.inserting_before(node):
                new_node = node.graph.call_function(
                    exir_ops.edge.aten.transpose_copy.int,
                    args=(in_tensor, diff[0], diff[1]),
                )
            new_node.meta = node.meta
            node.replace_all_uses_with(new_node)
            return True

        return False
473-
474438
475439@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
476440class ReplaceConvolutionOptionalArgsWithConcreteArgsPass (RemoveOrReplacePassInterface ):
@@ -798,9 +762,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798762 # gather stencil. Also, the first two dimensions of weight must be
799763 # transposed/interchanged.
800764 assert isinstance (weight , torch .fx .Node )
765+ weight_ndim = len (weight .meta ["val" ].shape )
766+ perm = transpose_dims_to_permute_order (weight_ndim , 0 , 1 )
801767 transposed_weight = node .graph .call_function (
802- exir_ops .edge .aten .transpose_copy . int ,
803- args = (weight , 0 , 1 ),
768+ exir_ops .edge .aten .permute_copy . default ,
769+ args = (weight , perm ),
804770 )
805771 transposed_weight .meta = weight .meta
806772
@@ -1036,18 +1002,19 @@ def targets(self) -> list[EdgeOpOverload]:
def _transpose_dims(
    self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
    """
    Swap two dimensions of `node` by emitting a permute_copy op.

    Both dims are first passed through canonicalize_transposed_dim against
    the node's shape, then ordered so the smaller one comes first before the
    permute order is built. The created node is returned.
    """
    in_shape = node.meta["val"].shape
    first = canonicalize_transposed_dim(dim0, in_shape)
    second = canonicalize_transposed_dim(dim1, in_shape)
    low, high = sorted((first, second))

    order = transpose_dims_to_permute_order(len(in_shape), low, high)
    swapped = graph.call_function(
        exir_ops.edge.aten.permute_copy.default, (node, order), {}
    )
    # The new node takes over the input node's metadata dict unchanged.
    swapped.meta = node.meta
    return swapped
10511018
10521019 def _change_nchw_to_nhwc (
10531020 self , graph : torch .fx .Graph , node : torch .fx .Node
@@ -1263,18 +1230,19 @@ def targets(self) -> list[EdgeOpOverload]:
def _transpose_dims(
    self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
    """
    Build a permute_copy node that exchanges two dimensions of `node`.

    The two dims go through canonicalize_transposed_dim using the shape
    stored in node.meta["val"], and are passed to the permute-order helper
    in ascending order. Returns the newly created graph node.
    """
    dims_shape = node.meta["val"].shape
    lo_dim = canonicalize_transposed_dim(dim0, dims_shape)
    hi_dim = canonicalize_transposed_dim(dim1, dims_shape)
    if hi_dim < lo_dim:
        lo_dim, hi_dim = hi_dim, lo_dim

    result = graph.call_function(
        exir_ops.edge.aten.permute_copy.default,
        (node, transpose_dims_to_permute_order(len(dims_shape), lo_dim, hi_dim)),
        {},
    )
    # Metadata is shared with the input node, not copied.
    result.meta = node.meta
    return result
12781246
12791247 def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
12801248 # Get the dimension argument
@@ -1526,8 +1494,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15261494 if not channel_last :
15271495 with graph .inserting_before (node ):
15281496 linear_res = graph .call_function (
1529- exir_ops .edge .aten .transpose_copy . int ,
1530- args = (linear_res , 1 , 2 ),
1497+ exir_ops .edge .aten .permute_copy . default ,
1498+ args = (linear_res , [ 0 , 2 , 1 ] ),
15311499 )
15321500 linear_res .meta = node .meta
15331501
@@ -1717,8 +1685,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17171685 if not channel_last :
17181686 with graph .inserting_before (node ):
17191687 linear_res = graph .call_function (
1720- exir_ops .edge .aten .transpose_copy . int ,
1721- args = (linear_res , 1 , 2 ),
1688+ exir_ops .edge .aten .permute_copy . default ,
1689+ args = (linear_res , [ 0 , 2 , 1 ] ),
17221690 )
17231691 linear_res .meta = node .meta
17241692
@@ -2375,9 +2343,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23752343
23762344 # Transpose Y_arg
23772345 with graph .inserting_before (node ):
2346+ Y_ndim = len (Y_tensor_val .shape )
2347+ perm = transpose_dims_to_permute_order (Y_ndim , - 1 , - 2 )
23782348 Y_arg_t = graph .call_function (
2379- exir_ops .edge .aten .transpose_copy . int ,
2380- args = (Y_arg , - 1 , - 2 ),
2349+ exir_ops .edge .aten .permute_copy . default ,
2350+ args = (Y_arg , perm ),
23812351 )
23822352 Y_arg_t .meta = node .meta
23832353
@@ -2410,13 +2380,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24102380 result = super ().call (graph_module )
24112381 modified = modified or result .modified
24122382 if modified :
2413- # Fuse any inserted transpose node with transpose/permute nodes
2383+ # Fuse any inserted permute node with transpose/permute nodes
24142384 # surrounding it.
24152385 result = FuseCascadedTransposeOrPermuteOps ().call (result .graph_module )
24162386 modified = modified or result .modified
2417- # Replace permute with transpose.
2418- result = ReplacePermuteWithTransposePass ().call (result .graph_module )
2419- modified = modified or result .modified
24202387
24212388 return PassResult (result .graph_module , modified )
24222389
@@ -2643,7 +2610,6 @@ class CadenceReplaceOpsInGraph:
26432610 ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass ,
26442611 ReplaceEmptyTensorsWithFullPass ,
26452612 ReplaceFunctionallyEquivalentOpTargets ,
2646- ReplacePermuteWithTransposePass ,
26472613 ReplaceConvolutionOptionalArgsWithConcreteArgsPass ,
26482614 ReplaceAddMMWithLinearPass ,
26492615 ReplacePadWithCatPass ,
0 commit comments