1919
2020import torch
2121import torch .fx
22- from executorch .backends .cadence .aot .compiler_utils import quantize_tensor_multiplier
22+ from executorch .backends .cadence .aot .compiler_utils import (
23+ quantize_tensor_multiplier ,
24+ transpose_dims_to_permute_order ,
25+ )
2326from executorch .backends .cadence .aot .fuse_ops import FuseCascadedTransposeOrPermuteOps
2427from executorch .backends .cadence .aot .pass_utils import (
2528 CadencePassAttribute ,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358
356359 # Handle transpose: if mat2 is a transpose op, extract the original tensor
357360 transposed_mat2 = False
358- if (
359- mat2 .op == "call_function"
360- and mat2 .target == exir_ops .edge .aten .transpose_copy . int
361+ if mat2 . op == "call_function" and (
362+ mat2 .target == exir_ops . edge . aten . transpose_copy . int
363+ or mat2 .target == exir_ops .edge .aten .permute_copy . default
361364 ):
362365 # mat2 is already transposed, so we use the input to the transpose
363366 mat2 = cast (torch .fx .Node , mat2 .args [0 ])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408 # Transpose mat2 if it wasn't already transposed
406409 if not transposed_mat2 :
407410 with graph .inserting_before (node ):
411+ ndim = len (mat2 .meta ["val" ].shape )
412+ perm = transpose_dims_to_permute_order (ndim , - 1 , - 2 )
408413 mat2 = graph .call_function (
409- exir_ops .edge .aten .transpose_copy . int ,
410- args = (mat2 , - 1 , - 2 ),
414+ exir_ops .edge .aten .permute_copy . default ,
415+ args = (mat2 , perm ),
411416 )
412417
413418 # Metadata copy important
@@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435 return True
431436
432437
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface):
    """
    Canonicalize layout-changing ops: rewrite every transpose_copy.int as the
    equivalent permute_copy.default, so later passes only have to reason about
    a single layout-change op.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only transposes are matched; permutes are already in canonical form.
        return [exir_ops.edge.aten.transpose_copy.int]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # transpose_copy.int args are (input, dim0, dim1).
        src = node.args[0]
        assert isinstance(src, torch.fx.Node)
        d0, d1 = cast(int, node.args[1]), cast(int, node.args[2])
        rank = len(src.meta["val"].shape)
        order = transpose_dims_to_permute_order(rank, d0, d1)

        graph = node.graph
        with graph.inserting_before(node):
            permute = graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(src, order),
            )
        # Carry the original node's metadata over to the replacement.
        permute.meta = node.meta
        node.replace_all_uses_with(permute)
        return True
465+
466+
433467@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
434468class ReplacePermuteWithTransposePass (RemoveOrReplacePassInterface ):
435469 """
@@ -798,9 +832,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798832 # gather stencil. Also, the first two dimensions of weight must be
799833 # transposed/interchanged.
800834 assert isinstance (weight , torch .fx .Node )
835+ weight_ndim = len (weight .meta ["val" ].shape )
836+ perm = transpose_dims_to_permute_order (weight_ndim , 0 , 1 )
801837 transposed_weight = node .graph .call_function (
802- exir_ops .edge .aten .transpose_copy . int ,
803- args = (weight , 0 , 1 ),
838+ exir_ops .edge .aten .permute_copy . default ,
839+ args = (weight , perm ),
804840 )
805841 transposed_weight .meta = weight .meta
806842
@@ -1037,18 +1073,19 @@ def targets(self) -> list[EdgeOpOverload]:
10371073 def _transpose_dims (
10381074 self , graph : torch .fx .Graph , node : torch .fx .Node , dim0 : int , dim1 : int
10391075 ) -> torch .fx .Node :
1040- """Helper function to transpose dims of a node."""
1076+ """Helper function to transpose dims of a node using permute ."""
10411077 shape = node .meta ["val" ].shape
10421078 dim0 , dim1 = (
10431079 canonicalize_transposed_dim (dim0 , shape ),
10441080 canonicalize_transposed_dim (dim1 , shape ),
10451081 )
10461082 dim0 , dim1 = min (dim0 , dim1 ), max (dim0 , dim1 )
1047- transpose_node = graph .call_function (
1048- exir_ops .edge .aten .transpose_copy .int , (node , dim0 , dim1 ), {}
1083+ perm = transpose_dims_to_permute_order (len (shape ), dim0 , dim1 )
1084+ permute_node = graph .call_function (
1085+ exir_ops .edge .aten .permute_copy .default , (node , perm ), {}
10491086 )
1050- transpose_node .meta = node .meta
1051- return transpose_node
1087+ permute_node .meta = node .meta
1088+ return permute_node
10521089
10531090 def _change_nchw_to_nhwc (
10541091 self , graph : torch .fx .Graph , node : torch .fx .Node
@@ -1273,18 +1310,19 @@ def targets(self) -> list[EdgeOpOverload]:
12731310 def _transpose_dims (
12741311 self , graph : torch .fx .Graph , node : torch .fx .Node , dim0 : int , dim1 : int
12751312 ) -> torch .fx .Node :
1276- """Helper function to transpose dims of a node."""
1313+ """Helper function to transpose dims of a node using permute ."""
12771314 shape = node .meta ["val" ].shape
12781315 dim0 , dim1 = (
12791316 canonicalize_transposed_dim (dim0 , shape ),
12801317 canonicalize_transposed_dim (dim1 , shape ),
12811318 )
12821319 dim0 , dim1 = min (dim0 , dim1 ), max (dim0 , dim1 )
1283- transpose_node = graph .call_function (
1284- exir_ops .edge .aten .transpose_copy .int , (node , dim0 , dim1 ), {}
1320+ perm = transpose_dims_to_permute_order (len (shape ), dim0 , dim1 )
1321+ permute_node = graph .call_function (
1322+ exir_ops .edge .aten .permute_copy .default , (node , perm ), {}
12851323 )
1286- transpose_node .meta = node .meta
1287- return transpose_node
1324+ permute_node .meta = node .meta
1325+ return permute_node
12881326
12891327 def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
12901328 # Get the dimension argument
@@ -1536,8 +1574,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15361574 if not channel_last :
15371575 with graph .inserting_before (node ):
15381576 linear_res = graph .call_function (
1539- exir_ops .edge .aten .transpose_copy . int ,
1540- args = (linear_res , 1 , 2 ),
1577+ exir_ops .edge .aten .permute_copy . default ,
1578+ args = (linear_res , [ 0 , 2 , 1 ] ),
15411579 )
15421580 linear_res .meta = node .meta
15431581
@@ -1727,8 +1765,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17271765 if not channel_last :
17281766 with graph .inserting_before (node ):
17291767 linear_res = graph .call_function (
1730- exir_ops .edge .aten .transpose_copy . int ,
1731- args = (linear_res , 1 , 2 ),
1768+ exir_ops .edge .aten .permute_copy . default ,
1769+ args = (linear_res , [ 0 , 2 , 1 ] ),
17321770 )
17331771 linear_res .meta = node .meta
17341772
@@ -2391,9 +2429,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23912429
23922430 # Transpose Y_arg
23932431 with graph .inserting_before (node ):
2432+ Y_ndim = len (Y_tensor_val .shape )
2433+ perm = transpose_dims_to_permute_order (Y_ndim , - 1 , - 2 )
23942434 Y_arg_t = graph .call_function (
2395- exir_ops .edge .aten .transpose_copy . int ,
2396- args = (Y_arg , - 1 , - 2 ),
2435+ exir_ops .edge .aten .permute_copy . default ,
2436+ args = (Y_arg , perm ),
23972437 )
23982438 Y_arg_t .meta = node .meta
23992439
@@ -2426,13 +2466,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24262466 result = super ().call (graph_module )
24272467 modified = modified or result .modified
24282468 if modified :
2429- # Fuse any inserted transpose node with transpose/permute nodes
2469+ # Fuse any inserted permute node with transpose/permute nodes
24302470 # surrounding it.
24312471 result = FuseCascadedTransposeOrPermuteOps ().call (result .graph_module )
24322472 modified = modified or result .modified
2433- # Replace permute with transpose.
2434- result = ReplacePermuteWithTransposePass ().call (result .graph_module )
2435- modified = modified or result .modified
24362473
24372474 return PassResult (result .graph_module , modified )
24382475
@@ -2656,10 +2693,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
26562693# graph with another.
26572694class CadenceReplaceOpsInGraph :
26582695 passes = CommonReplacePasses .passes + [
2696+ ReplaceTransposeWithPermutePass ,
26592697 ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass ,
26602698 ReplaceEmptyTensorsWithFullPass ,
26612699 ReplaceFunctionallyEquivalentOpTargets ,
2662- ReplacePermuteWithTransposePass ,
26632700 ReplaceConvolutionOptionalArgsWithConcreteArgsPass ,
26642701 ReplaceAddMMWithLinearPass ,
26652702 ReplacePadWithCatPass ,
@@ -2673,6 +2710,8 @@ class CadenceReplaceOpsInGraph:
26732710 ReplaceIm2RowWithViewPass ,
26742711 MakeSliceAndCatDimOutermostPass ,
26752712 ReplaceMatmulWithTransposedMatmulPass ,
2713+ # Convert permutes back to transposes after all passes that create them.
2714+ ReplacePermuteWithTransposePass ,
26762715 ReplaceNopTransposeOrPermuteWithViewPass ,
26772716 ReplaceLinearWithFullyConnectedOpPass ,
26782717 ReplaceScalarTensorWithFullPass ,
0 commit comments