1919
2020import torch
2121import torch .fx
22- from executorch .backends .cadence .aot .compiler_utils import quantize_tensor_multiplier
22+ from executorch .backends .cadence .aot .compiler_utils import (
23+ quantize_tensor_multiplier ,
24+ transpose_dims_to_permute_order ,
25+ )
2326from executorch .backends .cadence .aot .fuse_ops import FuseCascadedTransposeOrPermuteOps
2427from executorch .backends .cadence .aot .pass_utils import (
2528 CadencePassAttribute ,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358
356359 # Handle transpose: if mat2 is a transpose op, extract the original tensor
357360 transposed_mat2 = False
358- if (
359- mat2 .op == "call_function"
360- and mat2 .target == exir_ops .edge .aten .transpose_copy . int
361+ if mat2 . op == "call_function" and (
362+ mat2 .target == exir_ops . edge . aten . transpose_copy . int
363+ or mat2 .target == exir_ops .edge .aten .permute_copy . default
361364 ):
362365 # mat2 is already transposed, so we use the input to the transpose
363366 mat2 = cast (torch .fx .Node , mat2 .args [0 ])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408 # Transpose mat2 if it wasn't already transposed
406409 if not transposed_mat2 :
407410 with graph .inserting_before (node ):
411+ ndim = len (mat2 .meta ["val" ].shape )
412+ perm = transpose_dims_to_permute_order (ndim , - 1 , - 2 )
408413 mat2 = graph .call_function (
409- exir_ops .edge .aten .transpose_copy . int ,
410- args = (mat2 , - 1 , - 2 ),
414+ exir_ops .edge .aten .permute_copy . default ,
415+ args = (mat2 , perm ),
411416 )
412417
413418 # Metadata copy important
@@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435 return True
431436
432437
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface):
    """
    Canonicalize layout-change ops: rewrite every transpose_copy.int node as
    the equivalent permute_copy.default node, so permute is the single
    layout-change op in the graph.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only transpose_copy.int nodes are candidates for replacement.
        return [exir_ops.edge.aten.transpose_copy.int]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        source = node.args[0]
        assert isinstance(source, torch.fx.Node)
        swap_a = cast(int, node.args[1])
        swap_b = cast(int, node.args[2])
        # Expand the two swapped dims into a full permutation order.
        rank = len(source.meta["val"].shape)
        permutation = transpose_dims_to_permute_order(rank, swap_a, swap_b)

        graph = node.graph
        with graph.inserting_before(node):
            permute_node = graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(source, permutation),
            )
        # The permute produces the same output shape as the transpose it
        # replaces, so the original node's metadata carries over unchanged.
        permute_node.meta = node.meta
        node.replace_all_uses_with(permute_node)
        return True
465+
466+
433467@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
434468class ReplacePermuteWithTransposePass (RemoveOrReplacePassInterface ):
435469 """
@@ -471,7 +505,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
471505
472506 return False
473507
474-
475508@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
476509class ReplaceConvolutionOptionalArgsWithConcreteArgsPass (RemoveOrReplacePassInterface ):
477510 """
@@ -798,9 +831,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798831 # gather stencil. Also, the first two dimensions of weight must be
799832 # transposed/interchanged.
800833 assert isinstance (weight , torch .fx .Node )
834+ weight_ndim = len (weight .meta ["val" ].shape )
835+ perm = transpose_dims_to_permute_order (weight_ndim , 0 , 1 )
801836 transposed_weight = node .graph .call_function (
802- exir_ops .edge .aten .transpose_copy . int ,
803- args = (weight , 0 , 1 ),
837+ exir_ops .edge .aten .permute_copy . default ,
838+ args = (weight , perm ),
804839 )
805840 transposed_weight .meta = weight .meta
806841
@@ -1036,18 +1071,19 @@ def targets(self) -> list[EdgeOpOverload]:
10361071 def _transpose_dims (
10371072 self , graph : torch .fx .Graph , node : torch .fx .Node , dim0 : int , dim1 : int
10381073 ) -> torch .fx .Node :
1039- """Helper function to transpose dims of a node."""
1074+ """Helper function to transpose dims of a node using permute ."""
10401075 shape = node .meta ["val" ].shape
10411076 dim0 , dim1 = (
10421077 canonicalize_transposed_dim (dim0 , shape ),
10431078 canonicalize_transposed_dim (dim1 , shape ),
10441079 )
10451080 dim0 , dim1 = min (dim0 , dim1 ), max (dim0 , dim1 )
1046- transpose_node = graph .call_function (
1047- exir_ops .edge .aten .transpose_copy .int , (node , dim0 , dim1 ), {}
1081+ perm = transpose_dims_to_permute_order (len (shape ), dim0 , dim1 )
1082+ permute_node = graph .call_function (
1083+ exir_ops .edge .aten .permute_copy .default , (node , perm ), {}
10481084 )
1049- transpose_node .meta = node .meta
1050- return transpose_node
1085+ permute_node .meta = node .meta
1086+ return permute_node
10511087
10521088 def _change_nchw_to_nhwc (
10531089 self , graph : torch .fx .Graph , node : torch .fx .Node
@@ -1263,18 +1299,19 @@ def targets(self) -> list[EdgeOpOverload]:
12631299 def _transpose_dims (
12641300 self , graph : torch .fx .Graph , node : torch .fx .Node , dim0 : int , dim1 : int
12651301 ) -> torch .fx .Node :
1266- """Helper function to transpose dims of a node."""
1302+ """Helper function to transpose dims of a node using permute ."""
12671303 shape = node .meta ["val" ].shape
12681304 dim0 , dim1 = (
12691305 canonicalize_transposed_dim (dim0 , shape ),
12701306 canonicalize_transposed_dim (dim1 , shape ),
12711307 )
12721308 dim0 , dim1 = min (dim0 , dim1 ), max (dim0 , dim1 )
1273- transpose_node = graph .call_function (
1274- exir_ops .edge .aten .transpose_copy .int , (node , dim0 , dim1 ), {}
1309+ perm = transpose_dims_to_permute_order (len (shape ), dim0 , dim1 )
1310+ permute_node = graph .call_function (
1311+ exir_ops .edge .aten .permute_copy .default , (node , perm ), {}
12751312 )
1276- transpose_node .meta = node .meta
1277- return transpose_node
1313+ permute_node .meta = node .meta
1314+ return permute_node
12781315
12791316 def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
12801317 # Get the dimension argument
@@ -1526,8 +1563,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15261563 if not channel_last :
15271564 with graph .inserting_before (node ):
15281565 linear_res = graph .call_function (
1529- exir_ops .edge .aten .transpose_copy . int ,
1530- args = (linear_res , 1 , 2 ),
1566+ exir_ops .edge .aten .permute_copy . default ,
1567+ args = (linear_res , [ 0 , 2 , 1 ] ),
15311568 )
15321569 linear_res .meta = node .meta
15331570
@@ -1717,8 +1754,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17171754 if not channel_last :
17181755 with graph .inserting_before (node ):
17191756 linear_res = graph .call_function (
1720- exir_ops .edge .aten .transpose_copy . int ,
1721- args = (linear_res , 1 , 2 ),
1757+ exir_ops .edge .aten .permute_copy . default ,
1758+ args = (linear_res , [ 0 , 2 , 1 ] ),
17221759 )
17231760 linear_res .meta = node .meta
17241761
@@ -2375,9 +2412,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23752412
23762413 # Transpose Y_arg
23772414 with graph .inserting_before (node ):
2415+ Y_ndim = len (Y_tensor_val .shape )
2416+ perm = transpose_dims_to_permute_order (Y_ndim , - 1 , - 2 )
23782417 Y_arg_t = graph .call_function (
2379- exir_ops .edge .aten .transpose_copy . int ,
2380- args = (Y_arg , - 1 , - 2 ),
2418+ exir_ops .edge .aten .permute_copy . default ,
2419+ args = (Y_arg , perm ),
23812420 )
23822421 Y_arg_t .meta = node .meta
23832422
@@ -2410,13 +2449,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24102449 result = super ().call (graph_module )
24112450 modified = modified or result .modified
24122451 if modified :
2413- # Fuse any inserted transpose node with transpose/permute nodes
2452+ # Fuse any inserted permute node with transpose/permute nodes
24142453 # surrounding it.
24152454 result = FuseCascadedTransposeOrPermuteOps ().call (result .graph_module )
24162455 modified = modified or result .modified
2417- # Replace permute with transpose.
2418- result = ReplacePermuteWithTransposePass ().call (result .graph_module )
2419- modified = modified or result .modified
24202456
24212457 return PassResult (result .graph_module , modified )
24222458
@@ -2640,10 +2676,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
26402676# graph with another.
26412677class CadenceReplaceOpsInGraph :
26422678 passes = CommonReplacePasses .passes + [
2679+ ReplaceTransposeWithPermutePass ,
26432680 ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass ,
26442681 ReplaceEmptyTensorsWithFullPass ,
26452682 ReplaceFunctionallyEquivalentOpTargets ,
2646- ReplacePermuteWithTransposePass ,
26472683 ReplaceConvolutionOptionalArgsWithConcreteArgsPass ,
26482684 ReplaceAddMMWithLinearPass ,
26492685 ReplacePadWithCatPass ,
@@ -2657,6 +2693,8 @@ class CadenceReplaceOpsInGraph:
26572693 ReplaceIm2RowWithViewPass ,
26582694 MakeSliceAndCatDimOutermostPass ,
26592695 ReplaceMatmulWithTransposedMatmulPass ,
2696+ # Convert permutes back to transposes after all passes that create them.
2697+ ReplacePermuteWithTransposePass ,
26602698 ReplaceNopTransposeOrPermuteWithViewPass ,
26612699 ReplaceLinearWithFullyConnectedOpPass ,
26622700 ReplaceScalarTensorWithFullPass ,
0 commit comments