15 changes: 15 additions & 0 deletions backends/cadence/aot/compiler_utils.py
@@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
)


def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]:
"""
Convert transpose(dim0, dim1) to an equivalent permute order list.
E.g., transpose(0, 1) on a 3D tensor gives [1, 0, 2].
"""
# Normalize negative dims
if dim0 < 0:
dim0 += ndim
if dim1 < 0:
dim1 += ndim
order = list(range(ndim))
order[dim0], order[dim1] = order[dim1], order[dim0]
return order

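As a quick sanity check (not part of the diff), the new helper can be verified against torch's own transpose; the import path matches the file this hunk touches:

import torch
from executorch.backends.cadence.aot.compiler_utils import (
    transpose_dims_to_permute_order,
)

x = torch.randn(2, 3, 4)
# transpose(0, 1) on a 3-D tensor corresponds to the permute order [1, 0, 2].
perm = transpose_dims_to_permute_order(x.dim(), 0, 1)
assert perm == [1, 0, 2]
assert torch.equal(x.transpose(0, 1), x.permute(perm))
# Negative dims are normalized first: (-1, -2) on a 4-D tensor gives [0, 1, 3, 2].
assert transpose_dims_to_permute_order(4, -1, -2) == [0, 1, 3, 2]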

def get_transposed_dims(
node: torch.fx.Node, dims: Optional[List[int]] = None
) -> List[int]:
11 changes: 7 additions & 4 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -686,10 +686,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
quant_node,
)
elif isinstance(pattern, AddmmPattern):
# Transpose the weight tensor
# Transpose the weight tensor using permute
weight_ndim = len(weights_inputs[0].meta["val"].shape)
perm = list(range(weight_ndim))
perm[0], perm[1] = perm[1], perm[0]
transposed_weights = graph_module.graph.call_function(
torch.ops.aten.transpose.int,
(weights_inputs[0], 0, 1),
torch.ops.aten.permute.default,
(weights_inputs[0], perm),
)
assert (
"val" in weights_inputs[0].meta
@@ -700,7 +703,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
), "fake_mode is None on weight node"
with original_val.fake_mode:
transposed_weights.meta["val"] = (
torch.ops.aten.transpose.int(original_val, 0, 1)
torch.ops.aten.permute.default(original_val, perm)
)
copy_node_metadata(transposed_weights, weights_inputs[0])

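For the 2-D addmm weight this branch handles, the constructed perm is [1, 0], and the permute reproduces the old transpose exactly. A minimal check, with an illustrative weight shape:

import torch

w = torch.randn(5, 7)  # illustrative weight; only ndim == 2 matters here
perm = list(range(w.dim()))
perm[0], perm[1] = perm[1], perm[0]  # -> [1, 0]
assert torch.equal(
    torch.ops.aten.transpose.int(w, 0, 1),
    torch.ops.aten.permute.default(w, perm),
)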
97 changes: 68 additions & 29 deletions backends/cadence/aot/replace_ops.py
@@ -19,7 +19,10 @@

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier
from executorch.backends.cadence.aot.compiler_utils import (
quantize_tensor_multiplier,
transpose_dims_to_permute_order,
)
from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
@@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

# Handle transpose: if mat2 is a transpose op, extract the original tensor
transposed_mat2 = False
if (
mat2.op == "call_function"
and mat2.target == exir_ops.edge.aten.transpose_copy.int
if mat2.op == "call_function" and (
mat2.target == exir_ops.edge.aten.transpose_copy.int
or mat2.target == exir_ops.edge.aten.permute_copy.default
):
# mat2 is already transposed, so we use the input to the transpose
mat2 = cast(torch.fx.Node, mat2.args[0])
@@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Transpose mat2 if it wasn't already transposed
if not transposed_mat2:
with graph.inserting_before(node):
ndim = len(mat2.meta["val"].shape)
perm = transpose_dims_to_permute_order(ndim, -1, -2)
mat2 = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(mat2, -1, -2),
exir_ops.edge.aten.permute_copy.default,
args=(mat2, perm),
)

# Metadata copy important
@@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface):
"""
Replace transpose_copy.int ops with equivalent permute_copy.default ops
to canonicalize on permute as the single layout-change op.
"""

@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.transpose_copy.int]

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
in_tensor = node.args[0]
assert isinstance(in_tensor, torch.fx.Node)
ndim = len(in_tensor.meta["val"].shape)
dim0 = cast(int, node.args[1])
dim1 = cast(int, node.args[2])
perm = transpose_dims_to_permute_order(ndim, dim0, dim1)

with node.graph.inserting_before(node):
new_node = node.graph.call_function(
exir_ops.edge.aten.permute_copy.default,
args=(in_tensor, perm),
)
new_node.meta = node.meta
node.replace_all_uses_with(new_node)
return True

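The pass body is the usual FX replace idiom. A minimal standalone sketch of the same rewrite on a plain torch.fx trace, with torch.transpose and torch.permute standing in for the edge ops, and ndim hardcoded to 3 since this toy trace carries no meta["val"]:

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
    transpose_dims_to_permute_order,
)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.transpose(x, 1, 2)

gm = torch.fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.transpose:
        # The real pass reads ndim from node meta; 3 is hardcoded here.
        perm = transpose_dims_to_permute_order(3, node.args[1], node.args[2])
        with gm.graph.inserting_before(node):
            new_node = gm.graph.call_function(
                torch.permute, args=(node.args[0], perm)
            )
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
gm.recompile()

x = torch.randn(2, 3, 4)
assert torch.equal(gm(x), x.permute([0, 2, 1]))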

@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
"""
@@ -798,9 +832,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# gather stencil. Also, the first two dimensions of weight must be
# transposed/interchanged.
assert isinstance(weight, torch.fx.Node)
weight_ndim = len(weight.meta["val"].shape)
perm = transpose_dims_to_permute_order(weight_ndim, 0, 1)
transposed_weight = node.graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(weight, 0, 1),
exir_ops.edge.aten.permute_copy.default,
args=(weight, perm),
)
transposed_weight.meta = weight.meta

@@ -1037,18 +1073,19 @@ def targets(self) -> list[EdgeOpOverload]:
def _transpose_dims(
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
"""Helper function to transpose dims of a node."""
"""Helper function to transpose dims of a node using permute."""
shape = node.meta["val"].shape
dim0, dim1 = (
canonicalize_transposed_dim(dim0, shape),
canonicalize_transposed_dim(dim1, shape),
)
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
transpose_node = graph.call_function(
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
)
transpose_node.meta = node.meta
return transpose_node
permute_node.meta = node.meta
return permute_node
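The min/max ordering of the canonicalized dims is harmless here, since swapping two positions of the identity order is symmetric. A quick check (not part of the diff):

import torch
from executorch.backends.cadence.aot.compiler_utils import (
    transpose_dims_to_permute_order,
)

x = torch.randn(1, 8, 16, 16)
# (1, -1) canonicalizes to (1, 3); either argument order yields the same perm.
assert (
    transpose_dims_to_permute_order(4, 1, 3)
    == transpose_dims_to_permute_order(4, 3, 1)
    == [0, 3, 2, 1]
)
assert torch.equal(x.transpose(1, -1), x.permute([0, 3, 2, 1]))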

def _change_nchw_to_nhwc(
self, graph: torch.fx.Graph, node: torch.fx.Node
@@ -1273,18 +1310,19 @@ def targets(self) -> list[EdgeOpOverload]:
def _transpose_dims(
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
"""Helper function to transpose dims of a node."""
"""Helper function to transpose dims of a node using permute."""
shape = node.meta["val"].shape
dim0, dim1 = (
canonicalize_transposed_dim(dim0, shape),
canonicalize_transposed_dim(dim1, shape),
)
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
transpose_node = graph.call_function(
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
)
transpose_node.meta = node.meta
return transpose_node
permute_node.meta = node.meta
return permute_node

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Get the dimension argument
@@ -1536,8 +1574,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
if not channel_last:
with graph.inserting_before(node):
linear_res = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(linear_res, 1, 2),
exir_ops.edge.aten.permute_copy.default,
args=(linear_res, [0, 2, 1]),
)
linear_res.meta = node.meta

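The hardcoded [0, 2, 1] is safe here because linear_res is known to be 3-D, and on a 3-D tensor it is exactly transpose(1, 2):

import torch

res = torch.randn(2, 10, 6)  # illustrative (batch, channels, length) result
assert torch.equal(res.permute([0, 2, 1]), res.transpose(1, 2))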
@@ -1727,8 +1765,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
if not channel_last:
with graph.inserting_before(node):
linear_res = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(linear_res, 1, 2),
exir_ops.edge.aten.permute_copy.default,
args=(linear_res, [0, 2, 1]),
)
linear_res.meta = node.meta

@@ -2391,9 +2429,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

# Transpose Y_arg
with graph.inserting_before(node):
Y_ndim = len(Y_tensor_val.shape)
perm = transpose_dims_to_permute_order(Y_ndim, -1, -2)
Y_arg_t = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(Y_arg, -1, -2),
exir_ops.edge.aten.permute_copy.default,
args=(Y_arg, perm),
)
Y_arg_t.meta = node.meta

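Deriving the perm from (-1, -2) keeps this rewrite rank-agnostic: whatever the rank, only the two innermost (matrix) dims are swapped. An illustrative check:

import torch
from executorch.backends.cadence.aot.compiler_utils import (
    transpose_dims_to_permute_order,
)

y = torch.randn(4, 5, 6)
perm = transpose_dims_to_permute_order(y.dim(), -1, -2)  # -> [0, 2, 1]
assert torch.equal(y.transpose(-1, -2), y.permute(perm))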
@@ -2426,13 +2466,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
result = super().call(graph_module)
modified = modified or result.modified
if modified:
# Fuse any inserted transpose node with transpose/permute nodes
# Fuse any inserted permute node with transpose/permute nodes
# surrounding it.
result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module)
modified = modified or result.modified
# Replace permute with transpose.
result = ReplacePermuteWithTransposePass().call(result.graph_module)
modified = modified or result.modified

return PassResult(result.graph_module, modified)

@@ -2656,10 +2693,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# graph with another.
class CadenceReplaceOpsInGraph:
passes = CommonReplacePasses.passes + [
ReplaceTransposeWithPermutePass,
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplacePermuteWithTransposePass,
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
ReplaceAddMMWithLinearPass,
ReplacePadWithCatPass,
@@ -2673,6 +2710,8 @@ class CadenceReplaceOpsInGraph:
ReplaceIm2RowWithViewPass,
MakeSliceAndCatDimOutermostPass,
ReplaceMatmulWithTransposedMatmulPass,
# Convert permutes back to transposes after all passes that create them.
ReplacePermuteWithTransposePass,
ReplaceNopTransposeOrPermuteWithViewPass,
ReplaceLinearWithFullyConnectedOpPass,
ReplaceScalarTensorWithFullPass,