Skip to content

Commit 635dd46

Browse files
mcremon-meta authored and facebook-github-bot committed
Produce permutes instead of transposes
Summary: First step toward canonicalizing on permute in the graph compiler passes. Differential Revision: D100379751
1 parent 74403e2 commit 635dd46

4 files changed

Lines changed: 62 additions & 147 deletions

File tree

backends/cadence/aot/compiler_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
8787
)
8888

8989

90+
def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]:
91+
"""
92+
Convert transpose(dim0, dim1) to an equivalent permute order list.
93+
E.g., transpose(0, 1) on a 3D tensor gives [1, 0, 2].
94+
"""
95+
# Normalize negative dims
96+
if dim0 < 0:
97+
dim0 += ndim
98+
if dim1 < 0:
99+
dim1 += ndim
100+
order = list(range(ndim))
101+
order[dim0], order[dim1] = order[dim1], order[dim0]
102+
return order
103+
104+
90105
def get_transposed_dims(
91106
node: torch.fx.Node, dims: Optional[List[int]] = None
92107
) -> List[int]:

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
678678
quant_node,
679679
)
680680
elif isinstance(pattern, AddmmPattern):
681-
# Transpose the weight tensor
681+
# Transpose the weight tensor using permute
682+
weight_ndim = len(weights_inputs[0].meta["val"].shape)
683+
perm = list(range(weight_ndim))
684+
perm[0], perm[1] = perm[1], perm[0]
682685
transposed_weights = graph_module.graph.call_function(
683-
torch.ops.aten.transpose.int,
684-
(weights_inputs[0], 0, 1),
686+
torch.ops.aten.permute.default,
687+
(weights_inputs[0], perm),
685688
)
686689
assert (
687690
"val" in weights_inputs[0].meta
@@ -692,7 +695,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
692695
), "fake_mode is None on weight node"
693696
with original_val.fake_mode:
694697
transposed_weights.meta["val"] = (
695-
torch.ops.aten.transpose.int(original_val, 0, 1)
698+
torch.ops.aten.permute.default(original_val, perm)
696699
)
697700
copy_node_metadata(transposed_weights, weights_inputs[0])
698701

backends/cadence/aot/replace_ops.py

Lines changed: 36 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
import torch
2121
import torch.fx
22-
from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier
22+
from executorch.backends.cadence.aot.compiler_utils import (
23+
quantize_tensor_multiplier,
24+
transpose_dims_to_permute_order,
25+
)
2326
from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps
2427
from executorch.backends.cadence.aot.pass_utils import (
2528
CadencePassAttribute,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
355358

356359
# Handle transpose: if mat2 is a transpose op, extract the original tensor
357360
transposed_mat2 = False
358-
if (
359-
mat2.op == "call_function"
360-
and mat2.target == exir_ops.edge.aten.transpose_copy.int
361+
if mat2.op == "call_function" and (
362+
mat2.target == exir_ops.edge.aten.transpose_copy.int
363+
or mat2.target == exir_ops.edge.aten.permute_copy.default
361364
):
362365
# mat2 is already transposed, so we use the input to the transpose
363366
mat2 = cast(torch.fx.Node, mat2.args[0])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
405408
# Transpose mat2 if it wasn't already transposed
406409
if not transposed_mat2:
407410
with graph.inserting_before(node):
411+
ndim = len(mat2.meta["val"].shape)
412+
perm = transpose_dims_to_permute_order(ndim, -1, -2)
408413
mat2 = graph.call_function(
409-
exir_ops.edge.aten.transpose_copy.int,
410-
args=(mat2, -1, -2),
414+
exir_ops.edge.aten.permute_copy.default,
415+
args=(mat2, perm),
411416
)
412417

413418
# Metadata copy important
@@ -430,47 +435,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
430435
return True
431436

432437

433-
@register_cadence_pass(CadencePassAttribute(opt_level=1))
434-
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
435-
"""
436-
Replace permute op with transpose if the permutation is only along
437-
two dimensions.
438-
"""
439-
440-
@property
441-
def targets(self) -> list[EdgeOpOverload]:
442-
return [exir_ops.edge.aten.permute_copy.default]
443-
444-
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
445-
# Get the old dim and new dim order
446-
in_tensor = node.args[0]
447-
assert isinstance(in_tensor, torch.fx.Node)
448-
in_shape = in_tensor.meta["val"].shape
449-
old_dims = tuple(range(len(in_shape)))
450-
new_dims = cast(Sequence[int], node.args[1])
451-
452-
# Compute the number of positions in which the old and new order differ
453-
diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]
454-
455-
# If the difference is zero, replace with identity (just the input)
456-
if len(diff) == 0:
457-
node.replace_all_uses_with(in_tensor)
458-
return True
459-
460-
# If the difference is in two dimensions, we can replace this permute op
461-
# with transpose op.
462-
if len(diff) == 2:
463-
with node.graph.inserting_before(node):
464-
new_node = node.graph.call_function(
465-
exir_ops.edge.aten.transpose_copy.int,
466-
args=(node.args[0], diff[0], diff[1]),
467-
)
468-
new_node.meta = node.meta
469-
node.replace_all_uses_with(new_node)
470-
return True
471-
472-
return False
473-
474438

475439
@register_cadence_pass(CadencePassAttribute(opt_level=0))
476440
class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface):
@@ -798,9 +762,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
798762
# gather stencil. Also, the first two dimensions of weight must be
799763
# transposed/interchanged.
800764
assert isinstance(weight, torch.fx.Node)
765+
weight_ndim = len(weight.meta["val"].shape)
766+
perm = transpose_dims_to_permute_order(weight_ndim, 0, 1)
801767
transposed_weight = node.graph.call_function(
802-
exir_ops.edge.aten.transpose_copy.int,
803-
args=(weight, 0, 1),
768+
exir_ops.edge.aten.permute_copy.default,
769+
args=(weight, perm),
804770
)
805771
transposed_weight.meta = weight.meta
806772

@@ -1036,18 +1002,19 @@ def targets(self) -> list[EdgeOpOverload]:
10361002
def _transpose_dims(
10371003
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
10381004
) -> torch.fx.Node:
1039-
"""Helper function to transpose dims of a node."""
1005+
"""Helper function to transpose dims of a node using permute."""
10401006
shape = node.meta["val"].shape
10411007
dim0, dim1 = (
10421008
canonicalize_transposed_dim(dim0, shape),
10431009
canonicalize_transposed_dim(dim1, shape),
10441010
)
10451011
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1046-
transpose_node = graph.call_function(
1047-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1012+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1013+
permute_node = graph.call_function(
1014+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
10481015
)
1049-
transpose_node.meta = node.meta
1050-
return transpose_node
1016+
permute_node.meta = node.meta
1017+
return permute_node
10511018

10521019
def _change_nchw_to_nhwc(
10531020
self, graph: torch.fx.Graph, node: torch.fx.Node
@@ -1263,18 +1230,19 @@ def targets(self) -> list[EdgeOpOverload]:
12631230
def _transpose_dims(
12641231
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
12651232
) -> torch.fx.Node:
1266-
"""Helper function to transpose dims of a node."""
1233+
"""Helper function to transpose dims of a node using permute."""
12671234
shape = node.meta["val"].shape
12681235
dim0, dim1 = (
12691236
canonicalize_transposed_dim(dim0, shape),
12701237
canonicalize_transposed_dim(dim1, shape),
12711238
)
12721239
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1273-
transpose_node = graph.call_function(
1274-
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
1240+
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
1241+
permute_node = graph.call_function(
1242+
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
12751243
)
1276-
transpose_node.meta = node.meta
1277-
return transpose_node
1244+
permute_node.meta = node.meta
1245+
return permute_node
12781246

12791247
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
12801248
# Get the dimension argument
@@ -1526,8 +1494,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
15261494
if not channel_last:
15271495
with graph.inserting_before(node):
15281496
linear_res = graph.call_function(
1529-
exir_ops.edge.aten.transpose_copy.int,
1530-
args=(linear_res, 1, 2),
1497+
exir_ops.edge.aten.permute_copy.default,
1498+
args=(linear_res, [0, 2, 1]),
15311499
)
15321500
linear_res.meta = node.meta
15331501

@@ -1717,8 +1685,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
17171685
if not channel_last:
17181686
with graph.inserting_before(node):
17191687
linear_res = graph.call_function(
1720-
exir_ops.edge.aten.transpose_copy.int,
1721-
args=(linear_res, 1, 2),
1688+
exir_ops.edge.aten.permute_copy.default,
1689+
args=(linear_res, [0, 2, 1]),
17221690
)
17231691
linear_res.meta = node.meta
17241692

@@ -2375,9 +2343,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
23752343

23762344
# Transpose Y_arg
23772345
with graph.inserting_before(node):
2346+
Y_ndim = len(Y_tensor_val.shape)
2347+
perm = transpose_dims_to_permute_order(Y_ndim, -1, -2)
23782348
Y_arg_t = graph.call_function(
2379-
exir_ops.edge.aten.transpose_copy.int,
2380-
args=(Y_arg, -1, -2),
2349+
exir_ops.edge.aten.permute_copy.default,
2350+
args=(Y_arg, perm),
23812351
)
23822352
Y_arg_t.meta = node.meta
23832353

@@ -2410,13 +2380,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
24102380
result = super().call(graph_module)
24112381
modified = modified or result.modified
24122382
if modified:
2413-
# Fuse any inserted transpose node with transpose/permute nodes
2383+
# Fuse any inserted permute node with transpose/permute nodes
24142384
# surrounding it.
24152385
result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module)
24162386
modified = modified or result.modified
2417-
# Replace permute with transpose.
2418-
result = ReplacePermuteWithTransposePass().call(result.graph_module)
2419-
modified = modified or result.modified
24202387

24212388
return PassResult(result.graph_module, modified)
24222389

@@ -2643,7 +2610,6 @@ class CadenceReplaceOpsInGraph:
26432610
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
26442611
ReplaceEmptyTensorsWithFullPass,
26452612
ReplaceFunctionallyEquivalentOpTargets,
2646-
ReplacePermuteWithTransposePass,
26472613
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
26482614
ReplaceAddMMWithLinearPass,
26492615
ReplacePadWithCatPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 4 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
ReplaceMulTensorWithMulAndFullOpsPass,
3838
ReplaceNopTransposeOrPermuteWithViewPass,
3939
ReplacePadWithCatPass,
40-
ReplacePermuteWithTransposePass,
4140
ReplacePowWithMulPass,
4241
ReplaceRepeatWithCatPass,
4342
ReplaceScalarTensorWithFullPass,
@@ -965,10 +964,7 @@ def test_replace_linear_with_fully_connected(self) -> None:
965964
builder.output([mm])
966965
original_gm = builder.get_graph_module()
967966

968-
gm = cast(
969-
PassResult, ReplacePermuteWithTransposePass()(original_gm)
970-
).graph_module
971-
gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
967+
gm = cast(PassResult, ReplaceMMWithAddMMPass()(original_gm)).graph_module
972968

973969
gm_before_linear = copy.deepcopy(gm)
974970
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
@@ -1029,12 +1025,8 @@ def test_replace_addmm_with_linear(self) -> None:
10291025
builder.output([addmm])
10301026
original_gm = builder.get_graph_module()
10311027

1032-
gm = cast(
1033-
PassResult, ReplacePermuteWithTransposePass()(original_gm)
1034-
).graph_module
1035-
1036-
gm_before_linear = copy.deepcopy(gm)
1037-
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
1028+
gm_before_linear = copy.deepcopy(original_gm)
1029+
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm))
10381030
self.assertTrue(pass_result.modified)
10391031
graph_after_passes = pass_result.graph_module
10401032

@@ -1077,11 +1069,7 @@ def test_no_replace_addmm_with_linear(self) -> None:
10771069
builder.output([addmm])
10781070
original_gm = builder.get_graph_module()
10791071

1080-
gm = cast(
1081-
PassResult, ReplacePermuteWithTransposePass()(original_gm)
1082-
).graph_module
1083-
1084-
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
1072+
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm))
10851073
self.assertFalse(pass_result.modified)
10861074

10871075
@torch.no_grad()
@@ -1715,63 +1703,6 @@ def test_replace_nop_permute_with_view(
17151703
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
17161704
)
17171705

1718-
@expand(
1719-
[
1720-
# permutations replaced by transpose
1721-
[(3, 4), (1, 0)],
1722-
[(3, 4, 6), (0, 2, 1)],
1723-
]
1724-
)
1725-
@torch.no_grad()
1726-
def test_replace_permute_with_transpose(
1727-
self, shape: Tuple[int], dims: Tuple[int]
1728-
) -> None:
1729-
x = torch.randn(shape)
1730-
original_gm = single_op_builder(
1731-
placeholders=(x,),
1732-
op=exir_ops.edge.aten.permute_copy.default,
1733-
args=(x, dims),
1734-
)
1735-
1736-
gm_before = copy.deepcopy(original_gm)
1737-
p = ReplacePermuteWithTransposePass()
1738-
result = cast(PassResult, p(original_gm))
1739-
self.assertTrue(result.modified)
1740-
graph_after_passes = result.graph_module
1741-
inputs = [x]
1742-
validate(
1743-
gm_before, graph_after_passes, inputs, "ReplacePermuteWithTransposePass"
1744-
)
1745-
1746-
# Assert that permute op was replaced by a transpose op
1747-
self.assertEqual(
1748-
count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
1749-
)
1750-
self.assertEqual(
1751-
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1
1752-
)
1753-
1754-
@torch.no_grad()
1755-
def test_replace_permute_with_transpose_nop(
1756-
self,
1757-
) -> None:
1758-
x = torch.randn(3, 4)
1759-
original_gm = single_op_builder(
1760-
placeholders=(x,),
1761-
op=exir_ops.edge.aten.permute_copy.default,
1762-
args=(x, [0, 1]),
1763-
)
1764-
p = ReplacePermuteWithTransposePass()
1765-
graph_after_passes = cast(PassResult, p(original_gm)).graph_module
1766-
1767-
# Assert that permute op was replaced by a transpose op
1768-
self.assertEqual(
1769-
count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
1770-
)
1771-
self.assertEqual(
1772-
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
1773-
)
1774-
17751706

17761707
class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase):
17771708
def test_replace_aten_where_with_cadence(self) -> None:

0 commit comments

Comments (0)