Skip to content

Commit 2d31676

Browse files
mcremon-meta authored and facebook-github-bot committed
Canonicalize on permutes (#18822)
Summary: First step toward canonicalizing on permute in the graph compiler passes. Reviewed By: ethansfng Differential Revision: D100379751
1 parent 490ec5c commit 2d31676

4 files changed

Lines changed: 137 additions & 153 deletions

File tree

backends/cadence/aot/compiler_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
8787
)
8888

8989

90+
def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]:
    """
    Convert transpose(dim0, dim1) to an equivalent permute order list.

    E.g., transpose(0, 1) on a 3D tensor gives [1, 0, 2].

    Args:
        ndim: Rank of the tensor being transposed.
        dim0: First dimension to swap; may be negative (counted from the end).
        dim1: Second dimension to swap; may be negative (counted from the end).

    Returns:
        A permutation of range(ndim) with dim0 and dim1 exchanged.

    Raises:
        IndexError: If dim0 or dim1 is outside [-ndim, ndim), matching
            torch's contract for out-of-range transpose dims.
    """
    # Validate before normalizing: a single `+= ndim` would silently map a
    # dim < -ndim (e.g. -4 with ndim=3) onto a valid-but-wrong index, and a
    # dim >= ndim would only fail incidentally at the list-indexing below.
    for dim in (dim0, dim1):
        if not -ndim <= dim < ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}], but got {dim})"
            )
    # Normalize negative dims
    if dim0 < 0:
        dim0 += ndim
    if dim1 < 0:
        dim1 += ndim
    order = list(range(ndim))
    order[dim0], order[dim1] = order[dim1], order[dim0]
    return order
103+
104+
90105
def get_transposed_dims(
91106
node: torch.fx.Node, dims: Optional[List[int]] = None
92107
) -> List[int]:

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -686,10 +686,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
686686
quant_node,
687687
)
688688
elif isinstance(pattern, AddmmPattern):
689-
# Transpose the weight tensor
689+
# Transpose the weight tensor using permute
690+
weight_ndim = len(weights_inputs[0].meta["val"].shape)
691+
perm = list(range(weight_ndim))
692+
perm[0], perm[1] = perm[1], perm[0]
690693
transposed_weights = graph_module.graph.call_function(
691-
torch.ops.aten.transpose.int,
692-
(weights_inputs[0], 0, 1),
694+
torch.ops.aten.permute.default,
695+
(weights_inputs[0], perm),
693696
)
694697
assert (
695698
"val" in weights_inputs[0].meta
@@ -700,7 +703,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
700703
), "fake_mode is None on weight node"
701704
with original_val.fake_mode:
702705
transposed_weights.meta["val"] = (
703-
torch.ops.aten.transpose.int(original_val, 0, 1)
706+
torch.ops.aten.permute.default(original_val, perm)
704707
)
705708
copy_node_metadata(transposed_weights, weights_inputs[0])
706709

backends/cadence/aot/replace_ops.py

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
import torch
2121
import torch.fx
22-
from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier
22+
from executorch.backends.cadence.aot.compiler_utils import (
23+
quantize_tensor_multiplier,
24+
transpose_dims_to_permute_order,
25+
)
2326
from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps
2427
from executorch.backends.cadence.aot.pass_utils import (
2528
CadencePassAttribute,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358

356359
# Handle transpose: if mat2 is a transpose op, extract the original tensor
357360
transposed_mat2 = False
358-
if (
359-
mat2.op == "call_function"
360-
and mat2.target == exir_ops.edge.aten.transpose_copy.int
361+
if mat2.op == "call_function" and (
362+
mat2.target == exir_ops.edge.aten.transpose_copy.int
363+
or mat2.target == exir_ops.edge.aten.permute_copy.default
361364
):
362365
# mat2 is already transposed, so we use the input to the transpose
363366
mat2 = cast(torch.fx.Node, mat2.args[0])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408
# Transpose mat2 if it wasn't already transposed
406409
if not transposed_mat2:
407410
with graph.inserting_before(node):
411+
ndim = len(mat2.meta["val"].shape)
412+
perm = transpose_dims_to_permute_order(ndim, -1, -2)
408413
mat2 = graph.call_function(
409-
exir_ops.edge.aten.transpose_copy.int,
410-
args=(mat2, -1, -2),
414+
exir_ops.edge.aten.permute_copy.default,
415+
args=(mat2, perm),
411416
)
412417

413418
# Metadata copy important
@@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435
return True
431436

432437

438+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface):
    """
    Rewrite every transpose_copy.int node as an equivalent
    permute_copy.default node, canonicalizing on permute as the single
    layout-change op in the graph.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only transpose_copy nodes are visited by this pass.
        return [exir_ops.edge.aten.transpose_copy.int]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        source = node.args[0]
        assert isinstance(source, torch.fx.Node)
        rank = len(source.meta["val"].shape)
        swap_a = cast(int, node.args[1])
        swap_b = cast(int, node.args[2])
        # Build the permute order equivalent to swapping the two dims.
        order = transpose_dims_to_permute_order(rank, swap_a, swap_b)

        graph = node.graph
        with graph.inserting_before(node):
            permute = graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(source, order),
            )
        # Carry the original node's metadata onto the replacement before
        # rerouting all users to it.
        permute.meta = node.meta
        node.replace_all_uses_with(permute)
        return True
465+
466+
433467
@register_cadence_pass(CadencePassAttribute(opt_level=1))
434468
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
435469
"""
@@ -798,9 +832,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798832
# gather stencil. Also, the first two dimensions of weight must be
799833
# transposed/interchanged.
800834
assert isinstance(weight, torch.fx.Node)
835+
weight_ndim = len(weight.meta["val"].shape)
836+
perm = transpose_dims_to_permute_order(weight_ndim, 0, 1)
801837
transposed_weight = node.graph.call_function(
802-
exir_ops.edge.aten.transpose_copy.int,
803-
args=(weight, 0, 1),
838+
exir_ops.edge.aten.permute_copy.default,
839+
args=(weight, perm),
804840
)
805841
transposed_weight.meta = weight.meta
806842

@@ -1037,18 +1073,19 @@ def targets(self) -> list[EdgeOpOverload]:
10371073
def _transpose_dims(
10381074
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
10391075
) -> torch.fx.Node:
1040-
"""Helper function to transpose dims of a node."""
1076+
"""Helper function to transpose dims of a node using permute."""
10411077
shape = node.meta["val"].shape
10421078
dim0, dim1 = (
10431079
canonicalize_transposed_dim(dim0, shape),
10441080
canonicalize_transposed_dim(dim1, shape),
10451081
)
10461082
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1047-
transpose_node = graph.call_function(
1048-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1083+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1084+
permute_node = graph.call_function(
1085+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
10491086
)
1050-
transpose_node.meta = node.meta
1051-
return transpose_node
1087+
permute_node.meta = node.meta
1088+
return permute_node
10521089

10531090
def _change_nchw_to_nhwc(
10541091
self, graph: torch.fx.Graph, node: torch.fx.Node
@@ -1273,18 +1310,19 @@ def targets(self) -> list[EdgeOpOverload]:
12731310
def _transpose_dims(
12741311
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
12751312
) -> torch.fx.Node:
1276-
"""Helper function to transpose dims of a node."""
1313+
"""Helper function to transpose dims of a node using permute."""
12771314
shape = node.meta["val"].shape
12781315
dim0, dim1 = (
12791316
canonicalize_transposed_dim(dim0, shape),
12801317
canonicalize_transposed_dim(dim1, shape),
12811318
)
12821319
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1283-
transpose_node = graph.call_function(
1284-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1320+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1321+
permute_node = graph.call_function(
1322+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
12851323
)
1286-
transpose_node.meta = node.meta
1287-
return transpose_node
1324+
permute_node.meta = node.meta
1325+
return permute_node
12881326

12891327
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
12901328
# Get the dimension argument
@@ -1536,8 +1574,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15361574
if not channel_last:
15371575
with graph.inserting_before(node):
15381576
linear_res = graph.call_function(
1539-
exir_ops.edge.aten.transpose_copy.int,
1540-
args=(linear_res, 1, 2),
1577+
exir_ops.edge.aten.permute_copy.default,
1578+
args=(linear_res, [0, 2, 1]),
15411579
)
15421580
linear_res.meta = node.meta
15431581

@@ -1727,8 +1765,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17271765
if not channel_last:
17281766
with graph.inserting_before(node):
17291767
linear_res = graph.call_function(
1730-
exir_ops.edge.aten.transpose_copy.int,
1731-
args=(linear_res, 1, 2),
1768+
exir_ops.edge.aten.permute_copy.default,
1769+
args=(linear_res, [0, 2, 1]),
17321770
)
17331771
linear_res.meta = node.meta
17341772

@@ -2391,9 +2429,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23912429

23922430
# Transpose Y_arg
23932431
with graph.inserting_before(node):
2432+
Y_ndim = len(Y_tensor_val.shape)
2433+
perm = transpose_dims_to_permute_order(Y_ndim, -1, -2)
23942434
Y_arg_t = graph.call_function(
2395-
exir_ops.edge.aten.transpose_copy.int,
2396-
args=(Y_arg, -1, -2),
2435+
exir_ops.edge.aten.permute_copy.default,
2436+
args=(Y_arg, perm),
23972437
)
23982438
Y_arg_t.meta = node.meta
23992439

@@ -2426,13 +2466,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24262466
result = super().call(graph_module)
24272467
modified = modified or result.modified
24282468
if modified:
2429-
# Fuse any inserted transpose node with transpose/permute nodes
2469+
# Fuse any inserted permute node with transpose/permute nodes
24302470
# surrounding it.
24312471
result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module)
24322472
modified = modified or result.modified
2433-
# Replace permute with transpose.
2434-
result = ReplacePermuteWithTransposePass().call(result.graph_module)
2435-
modified = modified or result.modified
24362473

24372474
return PassResult(result.graph_module, modified)
24382475

@@ -2656,10 +2693,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
26562693
# graph with another.
26572694
class CadenceReplaceOpsInGraph:
26582695
passes = CommonReplacePasses.passes + [
2696+
ReplaceTransposeWithPermutePass,
26592697
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
26602698
ReplaceEmptyTensorsWithFullPass,
26612699
ReplaceFunctionallyEquivalentOpTargets,
2662-
ReplacePermuteWithTransposePass,
26632700
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
26642701
ReplaceAddMMWithLinearPass,
26652702
ReplacePadWithCatPass,
@@ -2673,6 +2710,8 @@ class CadenceReplaceOpsInGraph:
26732710
ReplaceIm2RowWithViewPass,
26742711
MakeSliceAndCatDimOutermostPass,
26752712
ReplaceMatmulWithTransposedMatmulPass,
2713+
# Convert permutes back to transposes after all passes that create them.
2714+
ReplacePermuteWithTransposePass,
26762715
ReplaceNopTransposeOrPermuteWithViewPass,
26772716
ReplaceLinearWithFullyConnectedOpPass,
26782717
ReplaceScalarTensorWithFullPass,

0 commit comments

Comments (0)