Skip to content

Commit d6c1bc4

Browse files
mcremon-meta and facebook-github-bot
authored and committed
Canonicalize on permutes (#18822)
Summary: First step toward canonicalizing on permute in the graph compiler passes. Reviewed By: ethansfng Differential Revision: D100379751
1 parent 875f7c8 commit d6c1bc4

4 files changed

Lines changed: 138 additions & 154 deletions

File tree

backends/cadence/aot/compiler_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
8787
)
8888

8989

90+
def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]:
    """
    Build the permute order equivalent to transpose(dim0, dim1).

    Starting from the identity order over ``ndim`` axes, the two
    transposed axes trade places; e.g. transpose(0, 1) on a 3D tensor
    yields [1, 0, 2]. Negative dims are interpreted Python-style
    (offset from the end).
    """
    # Resolve negative axis indices against ndim.
    axis_a = dim0 + ndim if dim0 < 0 else dim0
    axis_b = dim1 + ndim if dim1 < 0 else dim1
    # Identity permutation with the two axes exchanged.
    perm = list(range(ndim))
    perm[axis_a], perm[axis_b] = perm[axis_b], perm[axis_a]
    return perm
103+
104+
90105
def get_transposed_dims(
91106
node: torch.fx.Node, dims: Optional[List[int]] = None
92107
) -> List[int]:

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
678678
quant_node,
679679
)
680680
elif isinstance(pattern, AddmmPattern):
681-
# Transpose the weight tensor
681+
# Transpose the weight tensor using permute
682+
weight_ndim = len(weights_inputs[0].meta["val"].shape)
683+
perm = list(range(weight_ndim))
684+
perm[0], perm[1] = perm[1], perm[0]
682685
transposed_weights = graph_module.graph.call_function(
683-
torch.ops.aten.transpose.int,
684-
(weights_inputs[0], 0, 1),
686+
torch.ops.aten.permute.default,
687+
(weights_inputs[0], perm),
685688
)
686689
assert (
687690
"val" in weights_inputs[0].meta
@@ -692,7 +695,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
692695
), "fake_mode is None on weight node"
693696
with original_val.fake_mode:
694697
transposed_weights.meta["val"] = (
695-
torch.ops.aten.transpose.int(original_val, 0, 1)
698+
torch.ops.aten.permute.default(original_val, perm)
696699
)
697700
copy_node_metadata(transposed_weights, weights_inputs[0])
698701

backends/cadence/aot/replace_ops.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
import torch
2121
import torch.fx
22-
from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier
22+
from executorch.backends.cadence.aot.compiler_utils import (
23+
quantize_tensor_multiplier,
24+
transpose_dims_to_permute_order,
25+
)
2326
from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps
2427
from executorch.backends.cadence.aot.pass_utils import (
2528
CadencePassAttribute,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358

356359
# Handle transpose: if mat2 is a transpose op, extract the original tensor
357360
transposed_mat2 = False
358-
if (
359-
mat2.op == "call_function"
360-
and mat2.target == exir_ops.edge.aten.transpose_copy.int
361+
if mat2.op == "call_function" and (
362+
mat2.target == exir_ops.edge.aten.transpose_copy.int
363+
or mat2.target == exir_ops.edge.aten.permute_copy.default
361364
):
362365
# mat2 is already transposed, so we use the input to the transpose
363366
mat2 = cast(torch.fx.Node, mat2.args[0])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408
# Transpose mat2 if it wasn't already transposed
406409
if not transposed_mat2:
407410
with graph.inserting_before(node):
411+
ndim = len(mat2.meta["val"].shape)
412+
perm = transpose_dims_to_permute_order(ndim, -1, -2)
408413
mat2 = graph.call_function(
409-
exir_ops.edge.aten.transpose_copy.int,
410-
args=(mat2, -1, -2),
414+
exir_ops.edge.aten.permute_copy.default,
415+
args=(mat2, perm),
411416
)
412417

413418
# Metadata copy important
@@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435
return True
431436

432437

438+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface):
    """
    Replace transpose_copy.int ops with equivalent permute_copy.default ops
    to canonicalize on permute as the single layout-change op.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only transpose_copy nodes are rewritten by this pass.
        return [exir_ops.edge.aten.transpose_copy.int]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # transpose_copy.int args are (input, dim0, dim1).
        src = node.args[0]
        assert isinstance(src, torch.fx.Node)
        first_dim = cast(int, node.args[1])
        second_dim = cast(int, node.args[2])
        rank = len(src.meta["val"].shape)
        order = transpose_dims_to_permute_order(rank, first_dim, second_dim)

        # Insert an equivalent permute and route all users of the
        # transpose to it; the dead transpose is cleaned up by the
        # surrounding pass machinery.
        with node.graph.inserting_before(node):
            permute = node.graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(src, order),
            )
        permute.meta = node.meta
        node.replace_all_uses_with(permute)
        return True
465+
466+
433467
@register_cadence_pass(CadencePassAttribute(opt_level=1))
434468
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
435469
"""
@@ -471,7 +505,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
471505

472506
return False
473507

474-
475508
@register_cadence_pass(CadencePassAttribute(opt_level=0))
476509
class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface):
477510
"""
@@ -798,9 +831,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798831
# gather stencil. Also, the first two dimensions of weight must be
799832
# transposed/interchanged.
800833
assert isinstance(weight, torch.fx.Node)
834+
weight_ndim = len(weight.meta["val"].shape)
835+
perm = transpose_dims_to_permute_order(weight_ndim, 0, 1)
801836
transposed_weight = node.graph.call_function(
802-
exir_ops.edge.aten.transpose_copy.int,
803-
args=(weight, 0, 1),
837+
exir_ops.edge.aten.permute_copy.default,
838+
args=(weight, perm),
804839
)
805840
transposed_weight.meta = weight.meta
806841

@@ -1036,18 +1071,19 @@ def targets(self) -> list[EdgeOpOverload]:
10361071
def _transpose_dims(
10371072
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
10381073
) -> torch.fx.Node:
1039-
"""Helper function to transpose dims of a node."""
1074+
"""Helper function to transpose dims of a node using permute."""
10401075
shape = node.meta["val"].shape
10411076
dim0, dim1 = (
10421077
canonicalize_transposed_dim(dim0, shape),
10431078
canonicalize_transposed_dim(dim1, shape),
10441079
)
10451080
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1046-
transpose_node = graph.call_function(
1047-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1081+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1082+
permute_node = graph.call_function(
1083+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
10481084
)
1049-
transpose_node.meta = node.meta
1050-
return transpose_node
1085+
permute_node.meta = node.meta
1086+
return permute_node
10511087

10521088
def _change_nchw_to_nhwc(
10531089
self, graph: torch.fx.Graph, node: torch.fx.Node
@@ -1263,18 +1299,19 @@ def targets(self) -> list[EdgeOpOverload]:
12631299
def _transpose_dims(
12641300
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
12651301
) -> torch.fx.Node:
1266-
"""Helper function to transpose dims of a node."""
1302+
"""Helper function to transpose dims of a node using permute."""
12671303
shape = node.meta["val"].shape
12681304
dim0, dim1 = (
12691305
canonicalize_transposed_dim(dim0, shape),
12701306
canonicalize_transposed_dim(dim1, shape),
12711307
)
12721308
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1273-
transpose_node = graph.call_function(
1274-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1309+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1310+
permute_node = graph.call_function(
1311+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
12751312
)
1276-
transpose_node.meta = node.meta
1277-
return transpose_node
1313+
permute_node.meta = node.meta
1314+
return permute_node
12781315

12791316
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
12801317
# Get the dimension argument
@@ -1526,8 +1563,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15261563
if not channel_last:
15271564
with graph.inserting_before(node):
15281565
linear_res = graph.call_function(
1529-
exir_ops.edge.aten.transpose_copy.int,
1530-
args=(linear_res, 1, 2),
1566+
exir_ops.edge.aten.permute_copy.default,
1567+
args=(linear_res, [0, 2, 1]),
15311568
)
15321569
linear_res.meta = node.meta
15331570

@@ -1717,8 +1754,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17171754
if not channel_last:
17181755
with graph.inserting_before(node):
17191756
linear_res = graph.call_function(
1720-
exir_ops.edge.aten.transpose_copy.int,
1721-
args=(linear_res, 1, 2),
1757+
exir_ops.edge.aten.permute_copy.default,
1758+
args=(linear_res, [0, 2, 1]),
17221759
)
17231760
linear_res.meta = node.meta
17241761

@@ -2375,9 +2412,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23752412

23762413
# Transpose Y_arg
23772414
with graph.inserting_before(node):
2415+
Y_ndim = len(Y_tensor_val.shape)
2416+
perm = transpose_dims_to_permute_order(Y_ndim, -1, -2)
23782417
Y_arg_t = graph.call_function(
2379-
exir_ops.edge.aten.transpose_copy.int,
2380-
args=(Y_arg, -1, -2),
2418+
exir_ops.edge.aten.permute_copy.default,
2419+
args=(Y_arg, perm),
23812420
)
23822421
Y_arg_t.meta = node.meta
23832422

@@ -2410,13 +2449,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24102449
result = super().call(graph_module)
24112450
modified = modified or result.modified
24122451
if modified:
2413-
# Fuse any inserted transpose node with transpose/permute nodes
2452+
# Fuse any inserted permute node with transpose/permute nodes
24142453
# surrounding it.
24152454
result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module)
24162455
modified = modified or result.modified
2417-
# Replace permute with transpose.
2418-
result = ReplacePermuteWithTransposePass().call(result.graph_module)
2419-
modified = modified or result.modified
24202456

24212457
return PassResult(result.graph_module, modified)
24222458

@@ -2640,10 +2676,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
26402676
# graph with another.
26412677
class CadenceReplaceOpsInGraph:
26422678
passes = CommonReplacePasses.passes + [
2679+
ReplaceTransposeWithPermutePass,
26432680
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
26442681
ReplaceEmptyTensorsWithFullPass,
26452682
ReplaceFunctionallyEquivalentOpTargets,
2646-
ReplacePermuteWithTransposePass,
26472683
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
26482684
ReplaceAddMMWithLinearPass,
26492685
ReplacePadWithCatPass,
@@ -2657,6 +2693,8 @@ class CadenceReplaceOpsInGraph:
26572693
ReplaceIm2RowWithViewPass,
26582694
MakeSliceAndCatDimOutermostPass,
26592695
ReplaceMatmulWithTransposedMatmulPass,
2696+
# Convert permutes back to transposes after all passes that create them.
2697+
ReplacePermuteWithTransposePass,
26602698
ReplaceNopTransposeOrPermuteWithViewPass,
26612699
ReplaceLinearWithFullyConnectedOpPass,
26622700
ReplaceScalarTensorWithFullPass,

0 commit comments

Comments (0)