From 59c06be75587639902d1d8f763e88e879b03b4a3 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Tue, 21 Apr 2026 11:06:43 -0700 Subject: [PATCH] Move permute optimization passes to shared transforms location (#19002) Summary: Move 6 permute optimization passes and their shared infrastructure from executorch/backends/cadence/aot/ to executorch/backends/transforms/ so they can be shared between the Cadence and Arm backends without a cross-backend dependency. New files: - permute_pass_utils.py: base classes (HierarchicalInplacePassInterface, RemoveOrReplacePassInterface, FuseOpPairsAcrossBranchesPass) and utilities (get_arg, set_arg, get_transposed_dims, get_permuted_dims, get_shape, get_edge_overload_packet) - fuse_cascaded_transpose_or_permute_ops.py - fuse_cascaded_view_ops.py - fuse_transpose_or_permute_op_pairs_pass.py - remove_permutes_around_elementwise_ops.py - postpone_permute_below_squeeze_view.py - replace_nop_transpose_or_permute_with_view.py The shared versions omit register_cadence_pass decorators and cadence-specific ops from default op sets. Cadence files will subclass these and re-add the decorators and ops. Added OSS tests (test_permute_optimization_passes.py) for the 4 passes that can be imported without quantized op registration: FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, and ReplaceNopTransposeOrPermuteWithViewPass. These run in GitHub CI via pytest and are discovered automatically through pytest.ini testpaths. 
Reviewed By: ethansfng Differential Revision: D101459577 --- backends/cadence/aot/BUCK | 10 +- backends/cadence/aot/fuse_ops.py | 311 ++---------- backends/cadence/aot/pass_utils.py | 164 +------ backends/cadence/aot/remove_ops.py | 276 +---------- backends/cadence/aot/reorder_ops.py | 195 +------- backends/cadence/aot/replace_ops.py | 78 +--- backends/cadence/aot/tests/test_pass_utils.py | 20 +- .../fuse_cascaded_transpose_or_permute_ops.py | 69 +++ backends/transforms/fuse_cascaded_view_ops.py | 43 ++ ...fuse_transpose_or_permute_op_pairs_pass.py | 102 ++++ backends/transforms/permute_pass_utils.py | 278 +++++++++++ .../postpone_permute_below_squeeze_view.py | 207 ++++++++ .../remove_permutes_around_elementwise_ops.py | 272 +++++++++++ ...lace_nop_transpose_or_permute_with_view.py | 90 ++++ backends/transforms/targets.bzl | 111 +++++ .../test/test_permute_optimization_passes.py | 442 ++++++++++++++++++ 16 files changed, 1697 insertions(+), 971 deletions(-) create mode 100644 backends/transforms/fuse_cascaded_transpose_or_permute_ops.py create mode 100644 backends/transforms/fuse_cascaded_view_ops.py create mode 100644 backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py create mode 100644 backends/transforms/permute_pass_utils.py create mode 100644 backends/transforms/postpone_permute_below_squeeze_view.py create mode 100644 backends/transforms/remove_permutes_around_elementwise_ops.py create mode 100644 backends/transforms/replace_nop_transpose_or_permute_with_view.py create mode 100644 backends/transforms/test/test_permute_optimization_passes.py diff --git a/backends/cadence/aot/BUCK b/backends/cadence/aot/BUCK index 3eb77d4470f..5b5316245f8 100644 --- a/backends/cadence/aot/BUCK +++ b/backends/cadence/aot/BUCK @@ -81,9 +81,9 @@ fbcode_target(_kind = runtime.python_library, "pass_utils.py", ], deps = [ - "fbsource//third-party/pypi/beartype:beartype", ":utils", "//caffe2:torch", + "//executorch/backends/transforms:permute_pass_utils", 
"//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/passes:lib", @@ -188,7 +188,6 @@ fbcode_target(_kind = python_unittest, ], typing = True, deps = [ - "fbsource//third-party/pypi/beartype:beartype", ":pass_utils", "//caffe2:torch", ], @@ -267,6 +266,10 @@ fbcode_target(_kind = runtime.python_library, "//caffe2:torch", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops", + "//executorch/backends/transforms:fuse_cascaded_view_ops", + "//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass", + "//executorch/backends/transforms:permute_pass_utils", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", @@ -304,6 +307,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:simplify_ops", "//executorch/backends/transforms:remove_clone_ops", + "//executorch/backends/transforms:remove_permutes_around_elementwise_ops", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", @@ -322,6 +326,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/backends/cadence/aot:compiler_utils", "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:postpone_permute_below_squeeze_view", "//executorch/exir:pass_base", "//executorch/exir:tensor", "//executorch/exir/dialects:lib", @@ -343,6 +348,7 @@ fbcode_target(_kind = runtime.python_library, "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:remove_ops", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view", "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/exir:pass_base", 
"//executorch/exir/dialects:lib", diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index 023a6f5760a..d6ee88e94c6 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -12,9 +12,8 @@ import logging import math import operator -from collections import deque from numbers import Number -from typing import Any, Callable, cast, Optional, override +from typing import Any, cast, Optional, override # Import these for the cadence function signatures. import executorch.backends.cadence.aot.ops_registrations # noqa: F401 @@ -22,10 +21,8 @@ import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( broadcastable, - get_permuted_dims, get_scale, get_tensor_from_attr, - get_transposed_dims, get_zero_point, ) from executorch.backends.cadence.aot.pass_utils import ( @@ -36,9 +33,21 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps as _SharedFuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.fuse_cascaded_view_ops import ( + FuseCascadedViewOps as _SharedFuseCascadedViewOps, +) +from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import ( + FuseTransposeOrPermuteOpPairsPass as _SharedFuseTransposeOrPermuteOpPairsPass, +) +from executorch.backends.transforms.permute_pass_utils import ( + FuseOpPairsAcrossBranchesPass, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.pass_base import PassResult from executorch.exir.passes.cse_pass import CSEPass from torch.nn.utils.fusion import fuse_conv_bn_weights @@ -578,207 +587,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: 
@register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): - """ - Fuse a chain of transpose and permute ops into a single permute or a no-op. - Handles branches and chains permutes. - """ - - transpose_or_permute_target = { - exir_ops.edge.aten.transpose_copy.int, - exir_ops.edge.aten.permute_copy.default, - } - - @property - def targets(self) -> list[EdgeOpOverload]: - return list(self.transpose_or_permute_target) - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Fuse with the parent node if it's also a permute or a transpose. Since the - # pass interface traverses all ops in order the pass will properly fuse a chain - # of permutes. - parent_node = get_arg(node, "input", torch.fx.Node) - if parent_node.target not in self.transpose_or_permute_target: - return False - input_of_parent = get_arg(parent_node, "input", torch.fx.Node) - - # Compute combined effect of permutes. - dims = list(range(node.meta["val"].ndim)) - - if parent_node.target == exir_ops.edge.aten.transpose_copy.int: - dims = get_transposed_dims(parent_node, dims) - else: - dims = get_permuted_dims(parent_node, dims) - - if node.target == exir_ops.edge.aten.transpose_copy.int: - dims = get_transposed_dims(node, dims) - else: - dims = get_permuted_dims(node, dims) - - # If combined effect is identity replace the node with input. 
- if dims == sorted(dims): - node.replace_all_uses_with(input_of_parent) - else: - with node.graph.inserting_before(node): - new_permute = node.graph.call_function( - exir_ops.edge.aten.permute_copy.default, - args=(input_of_parent, dims), - ) - new_permute.meta = node.meta - node.replace_all_uses_with(new_permute) - - return True +class FuseCascadedTransposeOrPermuteOps(_SharedFuseCascadedTransposeOrPermuteOps): + pass @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedViewOps(RemoveOrReplacePassInterface): - """ - Fuse a cascaded chain of view ops - """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [exir_ops.edge.aten.view_copy.default] - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Check if the input to this view node is also a view node - input_view = node.args[0] - if not isinstance(input_view, torch.fx.Node): - return False - - if ( - input_view.op != "call_function" - or input_view.target != exir_ops.edge.aten.view_copy.default - ): - return False - - # Replace the input of this view node with the input of the cascaded view - # This effectively "skips" the intermediate view node - node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) - return True - - -class FuseOpPairsAcrossBranchesPass(ExportPass): - """ - Base class for passes that fuse op pairs across branches. - Provides common functionality for finding and fusing producer-consumer chains. - """ - - def check_ok_to_fuse( - self, - producer: torch.fx.Node, - consumers: list[torch.fx.Node], - ) -> bool: - # Always ok to replace / remove. 
- return True - - def can_fuse_for_chain( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - ) -> bool: - """ - Returns true if producer and consumer can be fused for a single chain - (-> producer -> ops -> consumer ->) to (-> ops -> fused_op) - """ - if ( - isinstance(consumer.target, EdgeOpOverload) - and get_edge_overload_packet(consumer.target) in consumer_op_packets - ): - return True - return False - - def get_fuse_candidates( - self, - producer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - bypass_ops: set[EdgeOpOverload], - ) -> list[torch.fx.Node]: - # Start by iterating over all the users of this node, and check - # if they are have their target in consumer_op_packets. - users = deque(producer.users.keys()) - # This holds the list of the user ops that directly (or transitively - # via view/slice) consume this producer_op_packets, and hence can be removed. - removal_candidates = [] - while users: - user = users.popleft() - - # If the user is a bypass op, we bypass it, and examine - # its users instead for consumer_op_packets. - if user.target in bypass_ops: - users.extend(list(user.users.keys())) - elif self.can_fuse_for_chain(producer, user, consumer_op_packets): - removal_candidates.append(user) - else: - removal_candidates.clear() - break - return removal_candidates - - def find_and_fuse( - self, - graph_module: torch.fx.GraphModule, - producer_op_packets: set[EdgeOpOverloadPacket], - consumer_op_packets: set[EdgeOpOverloadPacket], - bypass_ops: set[EdgeOpOverload], - ) -> bool: - """ - Find and fuse producer-consumer op pairs. - - Returns True if any fusion was performed, False otherwise. - """ - modified = False - for node in graph_module.graph.nodes: - # We are only interested in ops that have overload target in - # producer_op. 
- if not ( - isinstance(node.target, EdgeOpOverload) - and get_edge_overload_packet(node.target) in producer_op_packets - ): - continue - - removal_candidates = self.get_fuse_candidates( - node, consumer_op_packets, bypass_ops - ) - - if len(removal_candidates) == 0: - # No candidates found. - continue - - if not self.check_ok_to_fuse(node, removal_candidates): - # Not ok to remove quant-dequant pairs or replace with requantize. - continue - - self.fuse(node, removal_candidates, graph_module) - modified = True - - if modified: - graph_module.recompile() - - return modified - - def get_fused_node( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - graph_module: torch.fx.GraphModule, - ) -> torch.fx.Node: - return consumer - - def fuse( - self, - node: torch.fx.Node, - removal_candidates: list[torch.fx.Node], - graph_module: torch.fx.GraphModule, - ) -> None: - # Replace all the uses of the producer op with it's input. - node.replace_all_uses_with(cast(torch.fx.Node, node.args[0])) - graph_module.graph.erase_node(node) - - # Iterate over all the removal candidates (quantize op users) and generate replacements. - for rnode in removal_candidates: - rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module)) - graph_module.graph.erase_node(rnode) +class FuseCascadedViewOps(_SharedFuseCascadedViewOps): + pass @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1123,89 +938,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): - """ - Fuse transpose or permute op pairs to a single view op. - (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) - This happens when op2(op1) == identity, modulo unitary dimensions. 
- 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30] - so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused. - """ - - # A list of ops that can be bypassed when looking for a - # dequantize->quantize chain - bypass_ops: set[EdgeOpOverload] = { - exir_ops.edge.cadence.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - } - - def can_fuse_for_chain( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - consumer_op_packets: set[EdgeOpOverloadPacket], - ) -> bool: - if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): - return False - - # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions - producer_input = cast(torch.fx.Node, producer.args[0]) - if "val" not in producer_input.meta: - return False - input_shape = producer_input.meta["val"].shape - ident_dims = list(range(len(input_shape))) - # this mapping helps to handle both transpose and permutations - f: dict[Any, Callable] = { - exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, - exir_ops.edge.aten.permute_copy.default: get_permuted_dims, +class FuseTransposeOrPermuteOpPairsPass(_SharedFuseTransposeOrPermuteOpPairsPass): + bypass_ops: set[EdgeOpOverload] = ( + _SharedFuseTransposeOrPermuteOpPairsPass.bypass_ops + | { + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, } - in_dims = f[producer.target](producer, ident_dims) - out_dims = f[consumer.target](consumer, in_dims) - # Filtering out unitary dimensions - 
non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1] - non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1] - return non_unit_out_dims == non_unit_ident_dims - - def get_fused_node( - self, - producer: torch.fx.Node, - consumer: torch.fx.Node, - graph_module: torch.fx.GraphModule, - ) -> torch.fx.Node: - # This step is important because of how we can fuse transpositions that are not perfectly - # reverse one of another but will be fused if there are unitary dimensions. - # The fused operation must have the same output shape as the consumer. - output_shape = consumer.meta["val"].shape - with graph_module.graph.inserting_after(consumer): - view = graph_module.graph.call_function( - exir_ops.edge.aten.view_copy.default, - (consumer.args[0], output_shape), - {}, - ) - return view - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # Remove any transpose/permutation op pair that cancel each other. - modified = self.find_and_fuse( - graph_module, - producer_op_packets={ - exir_ops.edge.aten.transpose_copy, - exir_ops.edge.aten.permute_copy, - }, - consumer_op_packets={ - exir_ops.edge.aten.transpose_copy, - exir_ops.edge.aten.permute_copy, - }, - bypass_ops=self.bypass_ops, - ) - if modified: - return super().call(graph_module) - return PassResult(graph_module, False) + ) @register_cadence_pass(CadencePassAttribute(opt_level=1)) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index d03862d44fa..ab42ef43d56 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -7,20 +7,22 @@ # pyre-strict import dataclasses -from abc import abstractmethod from dataclasses import dataclass -from typing import Callable, List, Optional, override, Set, Type, TypeVar, Union +from typing import Callable, List, Optional, Set, Type, Union import torch -from beartype.door import die_if_unbearable from executorch.backends.cadence.aot.utils import 
get_edge_overload_packet + +# Re-exported for downstream consumers (noqa for flake8, `as X` for Pyre strict). +from executorch.backends.transforms.permute_pass_utils import ( # noqa: F401 + get_arg as get_arg, + HierarchicalInplacePassInterface as HierarchicalInplacePassInterface, + RemoveOrReplacePassInterface as RemoveOrReplacePassInterface, + set_arg as set_arg, +) from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, PassBase, PassResult +from executorch.exir.pass_base import PassBase, PassResult from torch._ops import OpOverloadPacket -from torch.fx import Node -from torch.fx.node import Argument - -T = TypeVar("T") # Is an overlap in tensor lifetime and storage allowed at the current opt level? @@ -207,152 +209,6 @@ def nodes_not_adjacent_in_gm( return True -def get_arg( - node: torch.fx.Node, - kwarg_name: str, - expected_type: Type[T] = Argument, -) -> T: - """ - Get the arg with arg_name of the node, returns default value if not set. - - Args: - node: The FX node to extract the argument from - kwarg_name: The name of the argument to extract - expected_type: Optional type to validate and cast the argument to. - If provided, asserts the argument is an instance of this type. 
- - Returns: - The argument value, optionally type-checked and cast to expected_type - - Example: - # Get a node argument with type checking - conv_weight_node = get_arg(node, "weight", torch.fx.Node) - - # Get a float argument with type checking - eps = get_arg(node, "eps", float) - - # Get an argument without type checking (returns Argument) - value = get_arg(node, "some_arg") - """ - # Try to get the arg from kwargs first since this is faster - if kwarg_name in node.kwargs: - value = node.kwargs[kwarg_name] - else: - # If it's not found in kwargs, try to normalize the args - normalized_args = node.normalized_arguments( - node.graph.owning_module, normalize_to_only_use_kwargs=True - ) - if not normalized_args: - raise RuntimeError( - f"get_arg: Node {node} does not support normalization of arguments" - ) - value = normalized_args.kwargs[kwarg_name] - - # Validate type using beartype's runtime type checker when a specific - # type is requested (not the default Argument type alias, which contains - # recursive forward references that beartype cannot resolve). - if expected_type is not Argument: - die_if_unbearable(value, expected_type) - return value # type: ignore[return-value] - - -def set_arg( - node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument -) -> None: - """ - Set the node's arg with its name to the given value. 
- """ - # Try to set the arg if it is present in kwargs first since this is faster - if kwarg_name in node.kwargs: - node.update_kwarg(kwarg_name, value) - return - - # If it's not found in kwargs, try to normalize the args and set the arg - normalized_args = node.normalized_arguments( - node.graph.owning_module, normalize_to_only_use_kwargs=True - ) - if not normalized_args: - raise RuntimeError( - f"set_arg: Node {node} does not support normalization of arguments" - ) - - kwargs = normalized_args.kwargs - if kwarg_name not in kwargs: - raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used") - - idx = list(kwargs.keys()).index(kwarg_name) - if idx < len(node.args): - node.update_arg(idx, value) - else: - node.update_kwarg(kwarg_name, value) - - def none_throws(x: Optional[PassResult]) -> PassResult: assert x is not None return x - - -class HierarchicalInplacePassInterface(ExportPass): - """A base class for passes that apply in-place modification to the graph module and its submodules. - Also calls ExportPass.call() in case the graph module is modified to ensure all nodes have valid `meta['val']`. 
- """ - - @abstractmethod - def _apply_flat_inplace(self, graph_module) -> bool: - """Apply in-place modification to the graph module.""" - raise NotImplementedError("`_apply_flat_inplace` must be implemented") - - def _apply_hierarchical_inplace(self, graph_module: torch.fx.GraphModule) -> bool: - """Apply in-place modification recursively to the graph module and its submodules.""" - - modified: bool = False - for module in filter( - lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() - ): - modified |= self._apply_flat_inplace(module) - - return modified - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - modified = self._apply_hierarchical_inplace(graph_module) - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - return super().call(graph_module) - - return PassResult(graph_module, False) - - -class RemoveOrReplacePassInterface(HierarchicalInplacePassInterface): - @property - @abstractmethod - def targets(self) -> list[EdgeOpOverload]: - """ - The list of targets to potentially remove or replace. - """ - raise NotImplementedError("`targets` must be implemented") - - @abstractmethod - def maybe_remove_or_replace(self, node: Node) -> bool: - """ - If the node should be removed/replaced, removes/replaces from the graph. Returns - True if the graph was modified, else False. - """ - raise NotImplementedError("`maybe_remove_or_replace` must be implemented") - - @override - def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool: - changed = False - for target in self.targets: - for node in graph_module.graph.find_nodes( - op="call_function", target=target - ): - if len(node.users) == 0: - # It is possible that maybe_remove_or_replace would have removed - # this target by starting from a different target. In this case, - # we should ignore it. If it wasn't erased, it will be handled - # in eliminate_dead_code. 
- continue - changed |= self.maybe_remove_or_replace(node) - return changed diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index a85b13452c1..dabab032116 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -6,7 +6,6 @@ # pyre-strict -from dataclasses import dataclass, field from typing import cast, List, Optional, Sequence, Set, Type # Import these for the cadence function signatures. @@ -26,6 +25,9 @@ from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps as _SharedRemovePermutesAroundElementwiseOps, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket from executorch.exir.pass_base import ExportPass, PassResult @@ -386,267 +388,17 @@ def maybe_remove_or_replace(self, node: Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class RemovePermutesAroundElementwiseOps(ExportPass): - """ - Looks for subgraphs of elementwise ops sandwiched between permutes and removes those - permutes if possible. - Allows special handling for certain non-elementwise ops that can be easily updated - based on the permute's parameter such as mean, cat, and slice. - """ - - @dataclass() - class Subgraph: - start_permute: list[int] - end_permute: list[int] - # Nodes in the subgraph, does not include permutes. - nodes: set[torch.fx.Node] = field(default_factory=set) - # Incoming edges to the subgraph from permute nodes. - edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) - # Outgoing edges of the subgraph to permute nodes. 
- edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set) - # Incoming edges from constant nodes that need a compensating permute. - constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field( - default_factory=set - ) - - permutable_ops: set[EdgeOpOverload] = { - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.hardtanh.default, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantize_per_tensor.default, - exir_ops.edge.cadence.dequantize_per_tensor.default, - exir_ops.edge.cadence.quantized_relu.per_tensor, - exir_ops.edge.cadence.requantize.per_tensor, - exir_ops.edge.cadence.quantized_add.per_tensor, - # Ops that require special handling. - exir_ops.edge.aten.cat.default, - exir_ops.edge.aten.mean.dim, - exir_ops.edge.aten.slice_copy.Tensor, - } - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = [] - processed_nodes: set[torch.fx.Node] = set() - for node in graph_module.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.permute_copy.default - ): - start_permute = self.get_permutation(node) - # Expected end permutation for the subgraph. - end_permute = [start_permute.index(i) for i in range(len(start_permute))] - - for user in node.users: - if user.target not in self.permutable_ops: - continue - # Create a separate subgraph for each user since there may be cases - # where only a portion of the users are permutable. 
- subgraph = self.Subgraph(start_permute, end_permute) - if self.visit(user, subgraph, processed_nodes): - subgraphs_found.append(subgraph) - for node in subgraph.nodes: - processed_nodes.add(node) - - modified = False - for subgraph in subgraphs_found: - self.permute_subgraph(subgraph) - modified = True - - if modified: - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - return super().call(graph_module) - - return PassResult(graph_module, False) - - def visit( # noqa: C901 - self, - node: torch.fx.Node, - subgraph: Subgraph, - processed_nodes: set[torch.fx.Node], - ) -> bool: - if node in subgraph.nodes: - return True - if node in processed_nodes or not self.is_node_permutable(node): - return False - subgraph.nodes.add(node) - - # Traverse downstream: - for user in node.users: - # Output should either go to a matching permute or another permutable op. - if user.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(user) != subgraph.end_permute: - return False - subgraph.edges_out.add((node, user)) - elif user.op == "output": - # Graph output requires the data in its original layout. - # Removing permutes here would silently change the output - # format, so treat this as an invalid subgraph boundary. - return False - elif not self.visit(user, subgraph, processed_nodes): - return False - - # Traverse upstream: - for inp in node.all_input_nodes: - # Input should either come from a matching permute or another permutable op. - if inp.target == exir_ops.edge.aten.permute_copy.default: - if self.get_permutation(inp) != subgraph.start_permute: - return False - subgraph.edges_in.add((inp, node)) - elif self._is_constant(inp): - # Only accept the constant if we can compensate it with a - # permute or view. Otherwise reject the subgraph. 
- const_rank = self._get_node_rank(inp) - if const_rank is None: - return False - if const_rank > len(subgraph.end_permute): - return False - if ( - const_rank < len(subgraph.end_permute) - and inp.meta.get("val") is None - ): - return False - subgraph.constant_edges_in.add((inp, node)) - elif not self.visit(inp, subgraph, processed_nodes): - return False - - return True - - def _is_constant(self, node: torch.fx.Node) -> bool: - """Check if a node's value is available at compile time. - Only considers direct constants (get_attr, parameter/buffer/constant - placeholders) — does not recurse into call_function chains to avoid - stack overflow on deep graphs.""" - if node.op == "get_attr": - return True - if node.op == "placeholder": - target = str(node.target) - return target.startswith(("b_", "p_", "c_")) - return False - - def _get_node_rank(self, node: torch.fx.Node) -> int | None: - """Return the tensor rank of a node's output, or None if unknown.""" - val = node.meta.get("val") - if val is not None and hasattr(val, "shape"): - return len(val.shape) - return None - - def is_node_permutable(self, node: torch.fx.Node) -> bool: - if node.target not in self.permutable_ops: - return False - if node.target == exir_ops.edge.aten.mean.dim: - # keepdim should be True. - if len(node.args) >= 3: - if not node.args[2]: - return False - elif "keepdim" in node.kwargs: - if not node.kwargs["keepdim"]: - return False - else: - # Default keepdim is False. - return False - return True - - def permute_subgraph(self, subgraph: Subgraph) -> None: - # Skip incoming permutes. - for inp, out in subgraph.edges_in: - assert inp.target == exir_ops.edge.aten.permute_copy.default - if len(inp.args) >= 1: - out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0])) - else: - out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"])) - - # Insert compensating permute on constant inputs. 
- # Since the subgraph's start permutes are being removed, the subgraph - # will operate in the un-permuted (original) layout. Constants that - # were in the permuted layout need end_permute (the inverse of - # start_permute) to convert back to the original layout. - for const_node, user_node in subgraph.constant_edges_in: - graph = const_node.graph - const_rank = self._get_node_rank(const_node) - permute_rank = len(subgraph.end_permute) - - with graph.inserting_after(const_node): - if const_rank is not None and const_rank == permute_rank: - new_node = graph.create_node( - "call_function", - exir_ops.edge.aten.permute_copy.default, - args=(const_node, subgraph.end_permute), - ) - elif ( - const_rank is not None - and const_rank < permute_rank - and const_node.meta.get("val") is not None - ): - # Rank mismatch (e.g. rank-1 bias with rank-4 permute). - # The constant is broadcastable and its shape is smaller - # than the permute rank, so we can't apply the permute - # directly. Instead, use view_copy to rearrange the - # shape according to the end_permute restricted to - # the trailing dimensions. - original_shape = list(const_node.meta["val"].shape) - # Pad shape to match permute rank for reordering - padded = [1] * (permute_rank - const_rank) + original_shape - target_shape = [padded[d] for d in subgraph.end_permute] - # Strip leading 1s back to original rank - target_shape = target_shape[permute_rank - const_rank :] - new_node = graph.create_node( - "call_function", - exir_ops.edge.aten.view_copy.default, - args=(const_node, target_shape), - ) - else: - # Cannot determine rank or handle this case; skip. - continue - user_node.replace_input_with(const_node, new_node) - - # Skip outgoing permutes. - for inp, out in subgraph.edges_out: - assert out.target == exir_ops.edge.aten.permute_copy.default - out.replace_all_uses_with(inp) - - # Handle dimension related node arguments. 
- for node in subgraph.nodes: - if node.target == exir_ops.edge.aten.cat.default: - self.update_cat(node, subgraph.start_permute) - elif node.target == exir_ops.edge.aten.mean.dim: - self.update_mean_dim(node, subgraph.start_permute) - elif node.target == exir_ops.edge.aten.slice_copy.Tensor: - self.update_slice_copy(node, subgraph.start_permute) - - def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg(1, start_permute[cast(int, node.args[1])]) - elif "dim" in node.kwargs: - node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) - else: - # Default cat dim is 0. - node.update_kwarg("dim", start_permute[0]) - - def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg( - 1, [start_permute[dim] for dim in cast(list[int], node.args[1])] - ) - else: - node.update_kwarg( - "dim", - [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])], - ) - - def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None: - if len(node.args) >= 2: - node.update_arg(1, start_permute[cast(int, node.args[1])]) - else: - node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])]) - - def get_permutation(self, permute_node: torch.fx.Node) -> list[int]: - assert permute_node.target == exir_ops.edge.aten.permute_copy.default - if len(permute_node.args) >= 2: - return cast(list[int], permute_node.args[1]) - assert "dim" in permute_node.kwargs - return cast(list[int], permute_node.kwargs["dim"]) +class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps): + permutable_ops: set[EdgeOpOverload] = ( + _SharedRemovePermutesAroundElementwiseOps.permutable_ops + | { + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, + exir_ops.edge.cadence.quantized_relu.per_tensor, + exir_ops.edge.cadence.requantize.per_tensor, + 
exir_ops.edge.cadence.quantized_add.per_tensor, + } + ) @register_cadence_pass(CadencePassAttribute(opt_level=2)) diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 8a0e112aaf3..e14471bc7ed 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -9,10 +9,9 @@ # This file contains all the functions that reorder ops in the graph module. -import copy from collections import defaultdict from math import prod -from typing import cast, DefaultDict, List, Tuple +from typing import DefaultDict, List, Tuple import torch import torch.fx @@ -24,6 +23,9 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView as _SharedPostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass, PassResult @@ -633,191 +635,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface): - """ - A common pattern seen in transformer models. If the consumer of permute - is a view op, swap their order so permute is below view. - Change "permute -> view" to "view -> permute" - This is to optimize a chain of view->permute->view->permute... - so that the chain will be become view->v...->view->permute->p...->permute. - The chain can be optimized by FuseCascadedTransposeOrPermuteOps() and - FuseCascadedViewOps(). - Notice the class name has ViewSqueeze to indicate the View is - functionally the same as a squeeze or unsqueeze. It does not necessarily - mean the view_copy is normalized from squeeze or unsqueeze. 
- """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [exir_ops.edge.aten.permute_copy.default] - - # If list1 and list2 are same (same values and in same order) except - # list1 has one more element with value of 1. Return index of the extra 1. - # Otherwise return -1. - def check_if_shapes_differ_in_single_dim_of_size_1( - self, list1: List, list2: List - ) -> int: - if len(list1) != len(list2) + 1: - return -1 - for i in range(len(list2)): - if list1[i] != list2[i]: - # Return index of the extra 1 if the remaining parts are the same - if list1[i] == 1 and list2[i:] == list1[i + 1 :]: - return i - else: - return -1 - # If no difference was found, the extra element is at the end - if list1[-1] == 1: - return len(list2) - else: - return -1 - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - users = list(node.users.keys()) - # Transform only for pattern permute_copy->view_copy, and - # view_copy op is the only user of permute_copy. - if len(users) != 1 or users[0].target not in ( - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.view.default, - ): - return False - - # If the permute_node/view_node was newly added to the - # graph, it may not have the meta["val"] FakeTensor. - # Skip in this case. - if node.meta.get("val") is None: - return False - - permute_node_shape = [*cast(list, get_shape(node.graph.owning_module, node))] - - permute_dims = cast(list, node.args[1]) - view_node = users[0] - - if view_node.meta.get("val") is None: - return False - - view_node_shape = [*cast(list, get_shape(node.graph.owning_module, view_node))] - - pred = node.args[0] - if not isinstance(pred, torch.fx.Node) or pred.meta.get("val") is None: - return False - - pred_shape = [*cast(list, get_shape(node.graph.owning_module, pred))] - - # Handle three cases - # 1. view_node_shape is almost same as permute_node_shape - # except the view_node has one more dim somewhere - # and the extra dim has value of 1. - # 2. 
view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. - # 3. view_node_shape is the same as permute_node_shape. - - if len(permute_node_shape) + 1 == len(view_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - view_node_shape, permute_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except it has one more dim somewhere - # and the extra dim has value of 1. - new_view_shape = copy.deepcopy(pred_shape) - new_view_shape.insert(index, 1) - new_permute_dims = [x + 1 if x >= index else x for x in permute_dims] - new_permute_dims.insert(index, index) - self._insert_nodes( - node.graph, - pred, - node, - view_node, - new_view_shape, - new_permute_dims, - ) - return True - - elif len(view_node_shape) + 1 == len(permute_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - permute_node_shape, view_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. 
- # Convert permute_dims to list of ints - index_to_remove = permute_dims[index] - new_view_shape = copy.deepcopy(pred_shape) - del new_view_shape[index_to_remove] - new_permute_dims = [ - x - 1 if x > index_to_remove else x for x in permute_dims - ] - del new_permute_dims[index] - self._insert_nodes( - node.graph, - pred, - node, - view_node, - new_view_shape, - new_permute_dims, - ) - return True - - elif permute_node_shape == view_node_shape: - # view_node_shape is the same as permute_node_shape - # Replace the uses of view_node with permute_node - view_node.replace_all_uses_with(node) - return True - - return False - - def _insert_nodes( - self, - graph: torch.fx.Graph, - pred: torch.fx.Node, - permute_node: torch.fx.Node, - view_node: torch.fx.Node, - new_view_shape: List, - new_permute_dims: List, - ) -> None: - with graph.inserting_after(view_node): - # Target is guaranteed to be a callable since it's from the graph - view_target = view_node.target - assert callable(view_target), "View target must be callable" - new_view_node = graph.call_function( - view_target, - args=(pred, new_view_shape), - ) - - with graph.inserting_after(new_view_node): - # Target is guaranteed to be a callable since it's from our targets list - permute_target = permute_node.target - assert callable(permute_target), "Permute target must be callable" - new_permute_node = graph.call_function( - permute_target, - args=(new_view_node, new_permute_dims), - ) - new_permute_node.meta = view_node.meta - view_node.replace_all_uses_with(new_permute_node) - - # view_node is user of permute_node, so must erase view_node first - graph.erase_node(view_node) - graph.erase_node(permute_node) - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # This pass needs to iterate until convergence because postponing - # one permute may enable postponing another in a chain - iter_count = 0 - local_modified = False - overall_modified = False - while local_modified or iter_count == 0: - 
result = super().call(graph_module) - local_modified = result.modified - overall_modified |= local_modified - graph_module = result.graph_module - iter_count += 1 - if iter_count == 4: - break - - return PassResult(graph_module, overall_modified) +class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView( + _SharedPostponePermuteOpBelowSqueezeOrUnsqueezeLikeView +): + pass # The following class consolidates functions to reoder ops (i.e., either hoist diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index e09a6589e76..4b60feb2121 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -28,6 +28,9 @@ RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import is_depthwise_conv +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( + ReplaceNopTransposeOrPermuteWithViewPass as _SharedReplaceNopTransposeOrPermuteWithViewPass, +) from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -1745,77 +1748,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface): - """ - If the transpose/permute op does not change the byte order (e.g., - transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced - by view op. 
- """ - - @property - def targets(self) -> list[EdgeOpOverload]: - return [ - exir_ops.edge.aten.transpose_copy.int, - exir_ops.edge.aten.permute_copy.default, - ] - - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Get the input tensor and shape - in_tensor_node = node.args[0] - assert isinstance(in_tensor_node, torch.fx.Node) - in_shape = in_tensor_node.meta["val"].shape - # Get the output tensor shape - out_shape = node.meta["val"].shape - - if node.target == exir_ops.edge.aten.transpose_copy.int: - # Get the two dims to be transposed - dim0 = cast(int, node.args[1]) - dim1 = cast(int, node.args[2]) - dim0 = dim0 if dim0 >= 0 else len(in_shape) + dim0 - dim1 = dim1 if dim1 >= 0 else len(in_shape) + dim1 - # We can eliminate transpose if (a) the size at dim0 and dim1 is 1; - # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive. - both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1 - either_one_and_consecutive = abs(dim0 - dim1) == 1 and ( - in_shape[dim0] == 1 or in_shape[dim1] == 1 - ) - if both_one or either_one_and_consecutive: - with node.graph.inserting_before(node): - new_node = node.graph.call_function( - exir_ops.edge.aten.view_copy.default, - args=(in_tensor_node, list(out_shape)), - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - return True - - elif node.target == exir_ops.edge.aten.permute_copy.default: - old_dims = list(range(len(in_shape))) - new_dims = cast(Sequence[int], node.args[1]) - # If the permute does not change anything, return the input as output. - if old_dims == list(new_dims): - node.replace_all_uses_with(in_tensor_node) - return True - # Get the old dim order, and the permuted dim order for all dims that - # are not 1. - old_order = [ - dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1 - ] - new_order = [ - dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1 - ] - # If the byte ordering for non-unit dims is unchanged, this is a nop. 
- if old_order == new_order: - with node.graph.inserting_before(node): - new_node = node.graph.call_function( - exir_ops.edge.aten.view_copy.default, - args=(in_tensor_node, list(out_shape)), - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - return True - - return False +class ReplaceNopTransposeOrPermuteWithViewPass( + _SharedReplaceNopTransposeOrPermuteWithViewPass +): + pass @register_cadence_pass(CadencePassAttribute(opt_level=2)) diff --git a/backends/cadence/aot/tests/test_pass_utils.py b/backends/cadence/aot/tests/test_pass_utils.py index 2776a370541..c9987cb7196 100644 --- a/backends/cadence/aot/tests/test_pass_utils.py +++ b/backends/cadence/aot/tests/test_pass_utils.py @@ -7,10 +7,8 @@ # pyre-strict import unittest -from typing import List import torch -from beartype.roar import BeartypeDoorHintViolation from executorch.backends.cadence.aot.pass_utils import get_arg @@ -61,9 +59,11 @@ def test_get_arg_with_list_type(self) -> None: self.assertEqual(result, [1, 2, 3]) def test_get_arg_with_list_int_type(self) -> None: - """Test get_arg validates parameterized List[int] type.""" + """Test get_arg accepts the plain list type for a list-of-int value.""" _, node = self._create_graph_with_kwargs(input=[1, 2, 3], other=2) - result = get_arg(node, "input", List[int]) + # get_arg validates with isinstance, which cannot check subscripted + # generics such as List[int], so this test passes the plain list type.
+ result = get_arg(node, "input", list) self.assertEqual(result, [1, 2, 3]) def test_get_arg_without_type_returns_value(self) -> None: @@ -73,13 +73,13 @@ def test_get_arg_without_type_returns_value(self) -> None: self.assertEqual(result, 42) def test_get_arg_type_mismatch_raises(self) -> None: - """Test get_arg raises BeartypeDoorHintViolation on type mismatch.""" + """Test get_arg raises TypeError on type mismatch.""" _, node = self._create_graph_with_kwargs(input="not_an_int", other=2) - with self.assertRaises(BeartypeDoorHintViolation): + with self.assertRaises(TypeError): get_arg(node, "input", int) def test_get_arg_list_type_mismatch_raises(self) -> None: - """Test get_arg raises BeartypeDoorHintViolation when list elements mismatch.""" - _, node = self._create_graph_with_kwargs(input=["a", "b"], other=2) - with self.assertRaises(BeartypeDoorHintViolation): - get_arg(node, "input", List[int]) + """Test get_arg raises TypeError when value is not a list.""" + _, node = self._create_graph_with_kwargs(input="not_a_list", other=2) + with self.assertRaises(TypeError): + get_arg(node, "input", list) diff --git a/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py new file mode 100644 index 00000000000..b8d6c75a174 --- /dev/null +++ b/backends/transforms/fuse_cascaded_transpose_or_permute_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from executorch.backends.transforms.permute_pass_utils import ( + get_arg, + get_permuted_dims, + get_transposed_dims, + RemoveOrReplacePassInterface, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from torch.fx import Node + + +class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): + """ + Fuse a chain of transpose and permute ops into a single permute or a no-op. + Handles branches and chains permutes. + """ + + transpose_or_permute_target = { + exir_ops.edge.aten.transpose_copy.int, + exir_ops.edge.aten.permute_copy.default, + } + + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transpose_or_permute_target) + + def maybe_remove_or_replace(self, node: Node) -> bool: + # Fuse with the parent node if it's also a permute or a transpose. Since the + # pass interface traverses all ops in order the pass will properly fuse a chain + # of permutes. + parent_node = get_arg(node, "input", Node) + if parent_node.target not in self.transpose_or_permute_target: + return False + input_of_parent = get_arg(parent_node, "input", Node) + + # Compute combined effect of permutes. + dims = list(range(node.meta["val"].ndim)) + + if parent_node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(parent_node, dims) + else: + dims = get_permuted_dims(parent_node, dims) + + if node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(node, dims) + else: + dims = get_permuted_dims(node, dims) + + # If combined effect is identity replace the node with input. 
+ if dims == sorted(dims): + node.replace_all_uses_with(input_of_parent) + else: + with node.graph.inserting_before(node): + new_permute = node.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(input_of_parent, dims), + ) + new_permute.meta = node.meta + node.replace_all_uses_with(new_permute) + + return True diff --git a/backends/transforms/fuse_cascaded_view_ops.py b/backends/transforms/fuse_cascaded_view_ops.py new file mode 100644 index 00000000000..7daf6ffe92e --- /dev/null +++ b/backends/transforms/fuse_cascaded_view_ops.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import cast + +import torch +from executorch.backends.transforms.permute_pass_utils import ( + RemoveOrReplacePassInterface, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload + + +class FuseCascadedViewOps(RemoveOrReplacePassInterface): + """ + Fuse a cascaded chain of view ops + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.view_copy.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if the input to this view node is also a view node + input_view = node.args[0] + if not isinstance(input_view, torch.fx.Node): + return False + + if ( + input_view.op != "call_function" + or input_view.target != exir_ops.edge.aten.view_copy.default + ): + return False + + # Replace the input of this view node with the input of the cascaded view + # This effectively "skips" the intermediate view node + node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) + return True diff --git a/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py 
b/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py new file mode 100644 index 00000000000..008775511ec --- /dev/null +++ b/backends/transforms/fuse_transpose_or_permute_op_pairs_pass.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import Any, Callable, cast + +import torch +import torch.fx +from executorch.backends.transforms.permute_pass_utils import ( + FuseOpPairsAcrossBranchesPass, + get_permuted_dims, + get_transposed_dims, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import PassResult + + +class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): + """ + Fuse transpose or permute op pairs to a single view op. + (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) + This happens when op2(op1) == identity, modulo unitary dimensions. + 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30] + so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused. + """ + + # A list of ops that can be bypassed when looking for a + # transpose-permute chain. Subclasses can extend this with backend-specific ops. 
+ bypass_ops: set[EdgeOpOverload] = { + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + } + + def can_fuse_for_chain( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + consumer_op_packets: set[EdgeOpOverloadPacket], + ) -> bool: + if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): + return False + + # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions + producer_input = cast(torch.fx.Node, producer.args[0]) + if "val" not in producer_input.meta: + return False + input_shape = producer_input.meta["val"].shape + ident_dims = list(range(len(input_shape))) + # this mapping helps to handle both transpose and permutations + f: dict[Any, Callable] = { + exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, + exir_ops.edge.aten.permute_copy.default: get_permuted_dims, + } + in_dims = f[producer.target](producer, ident_dims) + out_dims = f[consumer.target](consumer, in_dims) + # Filtering out unitary dimensions + non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1] + non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1] + return non_unit_out_dims == non_unit_ident_dims + + def get_fused_node( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + graph_module: torch.fx.GraphModule, + ) -> torch.fx.Node: + # This step is important because of how we can fuse transpositions that are not perfectly + # reverse one of another but will be fused if there are unitary dimensions. + # The fused operation must have the same output shape as the consumer. 
+ output_shape = consumer.meta["val"].shape + with graph_module.graph.inserting_after(consumer): + view = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (consumer.args[0], output_shape), + {}, + ) + return view + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Remove any transpose/permutation op pair that cancel each other. + modified = self.find_and_fuse( + graph_module, + producer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, + consumer_op_packets={ + exir_ops.edge.aten.transpose_copy, + exir_ops.edge.aten.permute_copy, + }, + bypass_ops=self.bypass_ops, + ) + if modified: + return super().call(graph_module) + return PassResult(graph_module, False) diff --git a/backends/transforms/permute_pass_utils.py b/backends/transforms/permute_pass_utils.py new file mode 100644 index 00000000000..fca8946165e --- /dev/null +++ b/backends/transforms/permute_pass_utils.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +"""Shared utilities and base classes for permute optimization passes. + +These were originally in executorch.backends.cadence.aot and are used by +both the Cadence and Arm backends. 
+""" + +from abc import abstractmethod +from collections import deque +from typing import cast, List, Optional, Type, TypeVar, Union + +import torch +import torch.fx +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node +from torch.fx.node import Argument + +T = TypeVar("T") + + +def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket: + edge_op_namespace, edge_op_name = ( + edge_op.namespace, + edge_op._schema.name.split("::")[1], + ) + edge_op_overload_packet = getattr( + getattr(exir_ops.edge, edge_op_namespace), edge_op_name + ) + return edge_op_overload_packet + + +def get_shape( + graph_module: torch.fx.GraphModule, node: torch.fx.Node +) -> Union[torch.Size, None]: + """Return the shape of the tensor corresponding to node.""" + try: + if isinstance(node, (float, int, bool)): + return torch.Size([1]) + fake_tensor = node.meta.get("val") + if fake_tensor is not None: + return fake_tensor.shape + if node.op == "get_attr": + attr_node = getattr(graph_module, node.target) + return attr_node.shape + return None + except RuntimeError: + return None + + +def get_transposed_dims( + node: torch.fx.Node, dims: Optional[List[int]] = None +) -> List[int]: + """Applies the transposition as given by node onto the dimensions given in input.""" + assert node.target == exir_ops.edge.aten.transpose_copy.int + assert dims is not None + dim_len = len(dims) + transpose_dims0 = node.args[1] + transpose_dims1 = node.args[2] + assert isinstance(transpose_dims0, int) + assert isinstance(transpose_dims1, int) + dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len + dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len + new_dims = list(dims) + new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0] + return new_dims + + +def 
get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]: + """Applies the permutation as given by node onto the dimensions given in input.""" + assert node.target == exir_ops.edge.aten.permute_copy.default + # pyre-fixme[6]: This combined typecheck isn't supported yet. + permute_dims: List[int] = list(node.args[1]) + assert all(isinstance(x, int) for x in permute_dims) + return [dims[x] for x in permute_dims] + + +def get_arg( + node: torch.fx.Node, + kwarg_name: str, + expected_type: Type[T] = Argument, +) -> T: + """Get the arg with kwarg_name of the node.""" + if kwarg_name in node.kwargs: + value = node.kwargs[kwarg_name] + else: + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"get_arg: Node {node} does not support normalization of arguments" + ) + value = normalized_args.kwargs[kwarg_name] + + if expected_type is not Argument: + try: + type_ok = isinstance(value, expected_type) + except TypeError: + # Subscripted generics (e.g. List[int]) don't support isinstance. + # Fall through — caller is responsible for correctness. 
+ type_ok = True + if not type_ok: + raise TypeError( + f"get_arg: expected {expected_type} for '{kwarg_name}', got {type(value)}" + ) + return value # type: ignore[return-value] + + +def set_arg( + node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument +) -> None: + """Set the node's arg with its name to the given value.""" + if kwarg_name in node.kwargs: + node.update_kwarg(kwarg_name, value) + return + + normalized_args = node.normalized_arguments( + node.graph.owning_module, normalize_to_only_use_kwargs=True + ) + if not normalized_args: + raise RuntimeError( + f"set_arg: Node {node} does not support normalization of arguments" + ) + + kwargs = normalized_args.kwargs + if kwarg_name not in kwargs: + raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used") + + idx = list(kwargs.keys()).index(kwarg_name) + if idx < len(node.args): + node.update_arg(idx, value) + else: + node.update_kwarg(kwarg_name, value) + + +class HierarchicalInplacePassInterface(ExportPass): + """A base class for passes that apply in-place modification to the graph module and its submodules.""" + + @abstractmethod + def _apply_flat_inplace(self, graph_module) -> bool: + raise NotImplementedError("`_apply_flat_inplace` must be implemented") + + def _apply_hierarchical_inplace(self, graph_module: torch.fx.GraphModule) -> bool: + modified: bool = False + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + modified |= self._apply_flat_inplace(module) + return modified + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = self._apply_hierarchical_inplace(graph_module) + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + return PassResult(graph_module, False) + + +class RemoveOrReplacePassInterface(HierarchicalInplacePassInterface): + @property + @abstractmethod + def targets(self) -> list[EdgeOpOverload]: + raise 
NotImplementedError("`targets` must be implemented") + + @abstractmethod + def maybe_remove_or_replace(self, node: Node) -> bool: + raise NotImplementedError("`maybe_remove_or_replace` must be implemented") + + def _apply_flat_inplace(self, graph_module: torch.fx.GraphModule) -> bool: + changed = False + for target in self.targets: + for node in graph_module.graph.find_nodes( + op="call_function", target=target + ): + if len(node.users) == 0: + continue + changed |= self.maybe_remove_or_replace(node) + return changed + + +class FuseOpPairsAcrossBranchesPass(ExportPass): + """Base class for passes that fuse op pairs across branches.""" + + def check_ok_to_fuse( + self, + producer: torch.fx.Node, + consumers: list[torch.fx.Node], + ) -> bool: + return True + + def can_fuse_for_chain( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + consumer_op_packets: set[EdgeOpOverloadPacket], + ) -> bool: + if ( + isinstance(consumer.target, EdgeOpOverload) + and get_edge_overload_packet(consumer.target) in consumer_op_packets + ): + return True + return False + + def get_fuse_candidates( + self, + producer: torch.fx.Node, + consumer_op_packets: set[EdgeOpOverloadPacket], + bypass_ops: set[EdgeOpOverload], + ) -> list[torch.fx.Node]: + users = deque(producer.users.keys()) + removal_candidates = [] + while users: + user = users.popleft() + if user.target in bypass_ops: + users.extend(list(user.users.keys())) + elif self.can_fuse_for_chain(producer, user, consumer_op_packets): + removal_candidates.append(user) + else: + removal_candidates.clear() + break + return removal_candidates + + def find_and_fuse( + self, + graph_module: torch.fx.GraphModule, + producer_op_packets: set[EdgeOpOverloadPacket], + consumer_op_packets: set[EdgeOpOverloadPacket], + bypass_ops: set[EdgeOpOverload], + ) -> bool: + modified = False + for node in graph_module.graph.nodes: + if not ( + isinstance(node.target, EdgeOpOverload) + and get_edge_overload_packet(node.target) in 
producer_op_packets + ): + continue + removal_candidates = self.get_fuse_candidates( + node, consumer_op_packets, bypass_ops + ) + if len(removal_candidates) == 0: + continue + if not self.check_ok_to_fuse(node, removal_candidates): + continue + self.fuse(node, removal_candidates, graph_module) + modified = True + if modified: + graph_module.recompile() + return modified + + def get_fused_node( + self, + producer: torch.fx.Node, + consumer: torch.fx.Node, + graph_module: torch.fx.GraphModule, + ) -> torch.fx.Node: + return consumer + + def fuse( + self, + node: torch.fx.Node, + removal_candidates: list[torch.fx.Node], + graph_module: torch.fx.GraphModule, + ) -> None: + node.replace_all_uses_with(cast(torch.fx.Node, node.args[0])) + graph_module.graph.erase_node(node) + for rnode in removal_candidates: + rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module)) + graph_module.graph.erase_node(rnode) diff --git a/backends/transforms/postpone_permute_below_squeeze_view.py b/backends/transforms/postpone_permute_below_squeeze_view.py new file mode 100644 index 00000000000..f676e19fb65 --- /dev/null +++ b/backends/transforms/postpone_permute_below_squeeze_view.py @@ -0,0 +1,207 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import copy +from typing import cast, List + +import torch +import torch.fx +from executorch.backends.transforms.permute_pass_utils import ( + get_shape, + RemoveOrReplacePassInterface, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import PassResult + + +class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface): + """ + A common pattern seen in transformer models. 
class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface):
    """
    A common pattern seen in transformer models. If the consumer of permute
    is a view op, swap their order so permute is below view.
    Change "permute -> view" to "view -> permute".
    This is to optimize a chain of view->permute->view->permute...
    so that the chain becomes view->v...->view->permute->p...->permute.
    The chain can then be optimized by FuseCascadedTransposeOrPermuteOps() and
    FuseCascadedViewOps().
    Notice the class name has SqueezeOrUnsqueezeLikeView to indicate the view
    is functionally the same as a squeeze or unsqueeze. It does not
    necessarily mean the view_copy is normalized from squeeze or unsqueeze.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.permute_copy.default]

    def check_if_shapes_differ_in_single_dim_of_size_1(
        self, list1: List, list2: List
    ) -> int:
        """
        Return the index of the extra size-1 entry when `list1` equals `list2`
        (same values, same order) plus exactly one additional element whose
        value is 1. Return -1 otherwise.
        """
        if len(list1) != len(list2) + 1:
            return -1
        for i, (longer_dim, shorter_dim) in enumerate(zip(list1, list2)):
            if longer_dim != shorter_dim:
                # First mismatch: it must be the extra 1, and the remaining
                # parts must line up once it is skipped.
                if longer_dim == 1 and list2[i:] == list1[i + 1 :]:
                    return i
                return -1
        # No difference was found, so the extra element is the last one.
        return len(list2) if list1[-1] == 1 else -1

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """
        Rewrite `permute_copy -> view_copy` into `view_copy -> permute_copy`
        when the view only adds or drops one size-1 dim, or drop the view when
        it does not change the shape. Returns True when the graph was changed.
        """
        users = list(node.users.keys())
        # Transform only for pattern permute_copy->view_copy, and
        # view_copy op is the only user of permute_copy.
        if len(users) != 1 or users[0].target not in (
            exir_ops.edge.aten.view_copy.default,
            exir_ops.edge.aten.view.default,
        ):
            return False

        # If the permute_node/view_node was newly added to the
        # graph, it may not have the meta["val"] FakeTensor.
        # Skip in this case.
        if node.meta.get("val") is None:
            return False

        permute_node_shape = [*cast(list, get_shape(node.graph.owning_module, node))]

        # Normalize negative dims: the index arithmetic below compares dims
        # against positions and is only correct for values in [0, rank).
        raw_permute_dims = cast(list, node.args[1])
        rank = len(raw_permute_dims)
        permute_dims = [dim % rank for dim in raw_permute_dims]
        view_node = users[0]

        if view_node.meta.get("val") is None:
            return False

        view_node_shape = [*cast(list, get_shape(node.graph.owning_module, view_node))]

        pred = node.args[0]
        if not isinstance(pred, torch.fx.Node) or pred.meta.get("val") is None:
            return False

        pred_shape = [*cast(list, get_shape(node.graph.owning_module, pred))]

        # Handle three cases
        # 1. view_node_shape is almost same as permute_node_shape
        #    except the view_node has one more dim somewhere
        #    and the extra dim has value of 1.
        # 2. view_node_shape is almost same as permute_node_shape
        #    except permute_node_shape has one more dim somewhere
        #    and the extra dim has value of 1.
        # 3. view_node_shape is the same as permute_node_shape.

        if len(permute_node_shape) + 1 == len(view_node_shape):
            index = self.check_if_shapes_differ_in_single_dim_of_size_1(
                view_node_shape, permute_node_shape
            )
            if index != -1:
                # Case 1: insert the unit dim into the un-permuted shape and
                # shift every permuted dim at or after it by one.
                new_view_shape = copy.deepcopy(pred_shape)
                new_view_shape.insert(index, 1)
                new_permute_dims = [x + 1 if x >= index else x for x in permute_dims]
                new_permute_dims.insert(index, index)
                self._insert_nodes(
                    node.graph,
                    pred,
                    node,
                    view_node,
                    new_view_shape,
                    new_permute_dims,
                )
                return True

        elif len(view_node_shape) + 1 == len(permute_node_shape):
            index = self.check_if_shapes_differ_in_single_dim_of_size_1(
                permute_node_shape, view_node_shape
            )
            if index != -1:
                # Case 2: the view drops output position `index`, which reads
                # input dim `permute_dims[index]`; remove that dim from the
                # un-permuted shape and renumber the dims above it.
                index_to_remove = permute_dims[index]
                new_view_shape = copy.deepcopy(pred_shape)
                del new_view_shape[index_to_remove]
                new_permute_dims = [
                    x - 1 if x > index_to_remove else x for x in permute_dims
                ]
                del new_permute_dims[index]
                self._insert_nodes(
                    node.graph,
                    pred,
                    node,
                    view_node,
                    new_view_shape,
                    new_permute_dims,
                )
                return True

        elif permute_node_shape == view_node_shape:
            # Case 3: the view does not change the shape, so its users can
            # read the permute directly.
            view_node.replace_all_uses_with(node)
            return True

        return False

    def _insert_nodes(
        self,
        graph: torch.fx.Graph,
        pred: torch.fx.Node,
        permute_node: torch.fx.Node,
        view_node: torch.fx.Node,
        new_view_shape: List,
        new_permute_dims: List,
    ) -> None:
        """
        Insert `view(pred, new_view_shape) -> permute(..., new_permute_dims)`
        after `view_node`, reroute the users of `view_node` to the new permute
        and erase the old view and permute nodes.
        """
        with graph.inserting_after(view_node):
            # Target is guaranteed to be a callable since it's from the graph
            view_target = view_node.target
            assert callable(view_target), "View target must be callable"
            new_view_node = graph.call_function(
                view_target,
                args=(pred, new_view_shape),
            )

        with graph.inserting_after(new_view_node):
            # Target is guaranteed to be a callable since it's from our targets list
            permute_target = permute_node.target
            assert callable(permute_target), "Permute target must be callable"
            new_permute_node = graph.call_function(
                permute_target,
                args=(new_view_node, new_permute_dims),
            )
            # The new permute produces what the old view produced, so it
            # takes over the old view's metadata.
            new_permute_node.meta = view_node.meta
            view_node.replace_all_uses_with(new_permute_node)

        # view_node is user of permute_node, so must erase view_node first
        graph.erase_node(view_node)
        graph.erase_node(permute_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass repeatedly, stopping at the first run that changes
        nothing or after four runs, whichever comes first.
        """
        # This pass needs to iterate until convergence because postponing
        # one permute may enable postponing another in a chain.
        max_iterations = 4
        overall_modified = False
        for _ in range(max_iterations):
            result = super().call(graph_module)
            graph_module = result.graph_module
            if not result.modified:
                break
            overall_modified = True

        return PassResult(graph_module, overall_modified)
class RemovePermutesAroundElementwiseOps(ExportPass):
    """
    Looks for subgraphs of elementwise ops sandwiched between permutes and removes those
    permutes if possible.
    Allows special handling for certain non-elementwise ops that can be easily updated
    based on the permute's parameter such as mean, cat, and slice.
    """

    @dataclass()
    class Subgraph:
        start_permute: list[int]
        end_permute: list[int]
        # Nodes in the subgraph, does not include permutes.
        nodes: set[torch.fx.Node] = field(default_factory=set)
        # Incoming edges to the subgraph from permute nodes.
        edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
        # Outgoing edges of the subgraph to permute nodes.
        edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
        # Incoming edges from constant nodes that need a compensating permute.
        constant_edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(
            default_factory=set
        )

    permutable_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        # Ops that require special handling.
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.slice_copy.Tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Collect every removable permute-sandwiched subgraph, rewrite them all,
        and re-run the base pass when anything changed.
        """
        subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
        processed_nodes: set[torch.fx.Node] = set()
        for permute_node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.permute_copy.default
        ):
            start_permute = self.get_permutation(permute_node)
            # Expected end permutation for the subgraph (inverse of start).
            end_permute = [start_permute.index(i) for i in range(len(start_permute))]

            for user in permute_node.users:
                if user.target not in self.permutable_ops:
                    continue
                # Create a separate subgraph for each user since there may be cases
                # where only a portion of the users are permutable.
                subgraph = self.Subgraph(start_permute, end_permute)
                if self.visit(user, subgraph, processed_nodes):
                    subgraphs_found.append(subgraph)
                    processed_nodes.update(subgraph.nodes)

        if not subgraphs_found:
            return PassResult(graph_module, False)

        for subgraph in subgraphs_found:
            self.permute_subgraph(subgraph)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return super().call(graph_module)

    def visit(  # noqa: C901
        self,
        node: torch.fx.Node,
        subgraph: Subgraph,
        processed_nodes: set[torch.fx.Node],
    ) -> bool:
        """
        Grow `subgraph` from `node` in both directions. Returns False as soon
        as a boundary is found that is neither a matching permute, an
        acceptable constant, nor another permutable op.
        """
        if node in subgraph.nodes:
            return True
        if node in processed_nodes or not self.is_node_permutable(node):
            return False
        subgraph.nodes.add(node)

        # Traverse downstream:
        for user in node.users:
            # Output should either go to a matching permute or another permutable op.
            if user.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(user) != subgraph.end_permute:
                    return False
                subgraph.edges_out.add((node, user))
            elif user.op == "output":
                # Graph output requires the data in its original layout.
                # Removing permutes here would silently change the output
                # format, so treat this as an invalid subgraph boundary.
                return False
            elif not self.visit(user, subgraph, processed_nodes):
                return False

        # Traverse upstream:
        for inp in node.all_input_nodes:
            # Input should either come from a matching permute or another permutable op.
            if inp.target == exir_ops.edge.aten.permute_copy.default:
                if self.get_permutation(inp) != subgraph.start_permute:
                    return False
                subgraph.edges_in.add((inp, node))
            elif self._is_constant(inp):
                # Only accept the constant if we can insert a compensating
                # permute or view. Otherwise reject the subgraph.
                const_rank = self._get_node_rank(inp)
                permute_rank = len(subgraph.end_permute)
                if const_rank is None:
                    return False
                if const_rank > permute_rank:
                    return False
                if const_rank < permute_rank and inp.meta.get("val") is None:
                    return False
                subgraph.constant_edges_in.add((inp, node))
            elif not self.visit(inp, subgraph, processed_nodes):
                return False

        return True

    def _is_constant(self, node: torch.fx.Node) -> bool:
        """Check if a node's value is available at compile time.
        Only considers direct constants (get_attr, parameter/buffer/constant
        placeholders) — does not recurse into call_function chains to avoid
        stack overflow on deep graphs."""
        if node.op == "get_attr":
            return True
        if node.op == "placeholder":
            target = str(node.target)
            return target.startswith(("b_", "p_", "c_"))
        return False

    def _get_node_rank(self, node: torch.fx.Node) -> int | None:
        """Return the tensor rank of a node's output, or None if unknown."""
        val = node.meta.get("val")
        if val is not None and hasattr(val, "shape"):
            return len(val.shape)
        return None

    def is_node_permutable(self, node: torch.fx.Node) -> bool:
        """Return True when `node` may sit inside a permute-free subgraph."""
        if node.target not in self.permutable_ops:
            return False
        if node.target == exir_ops.edge.aten.mean.dim:
            # keepdim should be True.
            if len(node.args) >= 3:
                if not node.args[2]:
                    return False
            elif "keepdim" in node.kwargs:
                if not node.kwargs["keepdim"]:
                    return False
            else:
                # Default keepdim is False.
                return False
        return True

    def _insert_constant_compensation(
        self, const_node: torch.fx.Node, end_permute: list[int]
    ) -> torch.fx.Node | None:
        """
        Insert, right after `const_node`, the op(s) that convert the constant
        from the permuted layout to the un-permuted one.

        Returns the node that should replace `const_node` as an input, or None
        when the constant's rank is unknown or larger than the permute rank.
        """
        const_rank = self._get_node_rank(const_node)
        permute_rank = len(end_permute)
        if const_rank is None or const_rank > permute_rank:
            # Cannot determine rank or handle this case; skip.
            return None

        graph = const_node.graph
        with graph.inserting_after(const_node):
            if const_rank == permute_rank:
                return graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(const_node, end_permute),
                )

            # Rank mismatch (e.g. rank-1 bias with rank-4 permute): the
            # constant broadcasts against the trailing dims. Pad its shape
            # with leading 1s so the permutation can be applied to it.
            missing = permute_rank - const_rank
            padded = [1] * missing + list(const_node.meta["val"].shape)
            target_shape = [padded[d] for d in end_permute]

            # A plain reshape is only equivalent to the permute when the
            # non-unit dims keep their relative order.
            order_before = [d for d in range(permute_rank) if padded[d] != 1]
            order_after = [d for d in end_permute if padded[d] != 1]
            if order_before == order_after:
                # Drop at most `missing` leading entries, and only while they
                # really are 1: slicing off a non-unit dim would change the
                # element count and make the view invalid.
                lead = 0
                while lead < missing and target_shape[lead] == 1:
                    lead += 1
                return graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.view_copy.default,
                    args=(const_node, target_shape[lead:]),
                )

            # Non-unit dims are reordered: expand to full rank, then permute.
            expanded = graph.create_node(
                "call_function",
                exir_ops.edge.aten.view_copy.default,
                args=(const_node, padded),
            )
            with graph.inserting_after(expanded):
                return graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(expanded, end_permute),
                )

    def permute_subgraph(self, subgraph: Subgraph) -> None:
        """Remove the permutes around `subgraph` and fix up everything that depended on them."""
        # Skip incoming permutes.
        for inp, out in subgraph.edges_in:
            assert inp.target == exir_ops.edge.aten.permute_copy.default
            if len(inp.args) >= 1:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
            else:
                out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))

        # Insert compensating permute on constant inputs.
        # Since the subgraph's start permutes are being removed, the subgraph
        # will operate in the un-permuted (original) layout. Constants that
        # were in the permuted layout need end_permute (the inverse of
        # start_permute) to convert back to the original layout.
        for const_node, user_node in subgraph.constant_edges_in:
            new_node = self._insert_constant_compensation(
                const_node, subgraph.end_permute
            )
            if new_node is None:
                continue
            user_node.replace_input_with(const_node, new_node)

        # Skip outgoing permutes.
        for inp, out in subgraph.edges_out:
            assert out.target == exir_ops.edge.aten.permute_copy.default
            out.replace_all_uses_with(inp)

        # Handle dimension related node arguments.
        for node in subgraph.nodes:
            if node.target == exir_ops.edge.aten.cat.default:
                self.update_cat(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.mean.dim:
                self.update_mean_dim(node, subgraph.start_permute)
            elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
                self.update_slice_copy(node, subgraph.start_permute)

    def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map the `dim` of a cat node through `start_permute`."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        elif "dim" in node.kwargs:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
        else:
            # Default cat dim is 0.
            node.update_kwarg("dim", start_permute[0])

    def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map every reduced dim of a mean.dim node through `start_permute`."""
        if len(node.args) >= 2:
            node.update_arg(
                1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
            )
        else:
            node.update_kwarg(
                "dim",
                [start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
            )

    def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
        """Map the `dim` of a slice_copy node through `start_permute`."""
        if len(node.args) >= 2:
            node.update_arg(1, start_permute[cast(int, node.args[1])])
        else:
            node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])

    def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
        """
        Return the dims of a permute_copy node as non-negative indices, so
        that permutations can be compared and inverted regardless of how the
        dims were spelled in the graph.
        """
        assert permute_node.target == exir_ops.edge.aten.permute_copy.default
        if len(permute_node.args) >= 2:
            dims = cast(list[int], permute_node.args[1])
        else:
            assert "dim" in permute_node.kwargs
            dims = cast(list[int], permute_node.kwargs["dim"])
        rank = len(dims)
        return [dim % rank for dim in dims]
class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface):
    """
    If the transpose/permute op does not change the byte order (e.g.,
    transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced
    by view op.
    """

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [
            exir_ops.edge.aten.transpose_copy.int,
            exir_ops.edge.aten.permute_copy.default,
        ]

    def _replace_with_view(
        self,
        node: torch.fx.Node,
        in_tensor_node: torch.fx.Node,
        out_shape: Sequence[int],
    ) -> None:
        """Reroute the users of `node` to a view_copy of its input with `node`'s output shape."""
        with node.graph.inserting_before(node):
            new_node = node.graph.call_function(
                exir_ops.edge.aten.view_copy.default,
                args=(in_tensor_node, list(out_shape)),
            )
            # The view produces exactly what `node` produced.
            new_node.meta = node.meta
        node.replace_all_uses_with(new_node)

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """
        Replace `node` with a view_copy (or with its input, for an identity
        permute) when it does not move any data. Returns True when the graph
        was changed.
        """
        in_tensor_node = node.args[0]
        if not isinstance(in_tensor_node, torch.fx.Node):
            return False
        # Nodes without a meta["val"] carry no shape information to decide
        # on, so they are skipped rather than failing the whole pass.
        in_val = in_tensor_node.meta.get("val")
        out_val = node.meta.get("val")
        if in_val is None or out_val is None:
            return False
        in_shape = in_val.shape
        out_shape = out_val.shape
        rank = len(in_shape)

        if node.target == exir_ops.edge.aten.transpose_copy.int:
            # Get the two dims to be transposed
            dim0 = cast(int, node.args[1])
            dim1 = cast(int, node.args[2])
            dim0 = dim0 if dim0 >= 0 else rank + dim0
            dim1 = dim1 if dim1 >= 0 else rank + dim1
            # We can eliminate transpose if (a) the size at dim0 and dim1 is 1;
            # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive.
            both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1
            either_one_and_consecutive = abs(dim0 - dim1) == 1 and (
                in_shape[dim0] == 1 or in_shape[dim1] == 1
            )
            if both_one or either_one_and_consecutive:
                self._replace_with_view(node, in_tensor_node, out_shape)
                return True
            return False

        if node.target == exir_ops.edge.aten.permute_copy.default:
            old_dims = list(range(rank))
            # Normalize negative dims so that e.g. [0, -1] on a rank-2 input
            # is recognised as the identity permutation.
            new_dims = [dim % rank for dim in cast(Sequence[int], node.args[1])]
            # If the permute does not change anything, return the input as output.
            if old_dims == new_dims:
                node.replace_all_uses_with(in_tensor_node)
                return True
            # Get the old dim order, and the permuted dim order for all dims that
            # are not 1.
            old_order = [
                dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1
            ]
            new_order = [
                dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1
            ]
            # If the byte ordering for non-unit dims is unchanged, this is a nop.
            if old_order == new_order:
                self._replace_with_view(node, in_tensor_node, out_shape)
                return True

        return False
deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], + ) + + runtime.python_library( + name = "postpone_permute_below_squeeze_view", + srcs = ["postpone_permute_below_squeeze_view.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ":permute_pass_utils", + ], + ) + + runtime.python_library( + name = "replace_nop_transpose_or_permute_with_view", + srcs = ["replace_nop_transpose_or_permute_with_view.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/dialects:lib", + ":permute_pass_utils", + ], + ) + + runtime.python_test( + name = "test_permute_optimization_passes", + srcs = [ + "test/test_permute_optimization_passes.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/test:graph_builder", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ":fuse_cascaded_transpose_or_permute_ops", + ":fuse_cascaded_view_ops", + ":postpone_permute_below_squeeze_view", + ":replace_nop_transpose_or_permute_with_view", + ], + ) diff --git a/backends/transforms/test/test_permute_optimization_passes.py b/backends/transforms/test/test_permute_optimization_passes.py new file mode 100644 index 00000000000..bb326f125bc --- /dev/null +++ b/backends/transforms/test/test_permute_optimization_passes.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import copy +import unittest +from typing import cast + +import torch +from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( + FuseCascadedTransposeOrPermuteOps, +) +from executorch.backends.transforms.fuse_cascaded_view_ops import FuseCascadedViewOps +from executorch.backends.transforms.postpone_permute_below_squeeze_view import ( + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, +) +from executorch.backends.transforms.replace_nop_transpose_or_permute_with_view import ( + ReplaceNopTransposeOrPermuteWithViewPass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import PassResult +from torch.utils import _pytree as pytree + + +def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int: + """Count the number of nodes with target `target` in the graph.""" + total = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + total += 1 + return total + + +def validate_numerics( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs.""" + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. 
" + f"Expected rtol={rtol}, atol={atol}." + ) + + +def get_compute_nodes( + graph_module: torch.fx.GraphModule, +) -> list: + """Return the target of each call_function node in order.""" + return [ + n.target + for n in graph_module.graph.nodes + if n.op == "call_function" + and n.target + not in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel.default, + ) + ] + + +# ────────────────────────────────────────────────────────────────────── +# Tests for FuseCascadedTransposeOrPermuteOps +# ────────────────────────────────────────────────────────────────────── + + +class FuseCascadedTransposeOrPermuteOpsTest(unittest.TestCase): + def test_permute_transpose_fusion(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4)) + permute = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 4, 1, 3]) + ) + transpose = builder.call_operator( + op=exir_ops.edge.aten.transpose_copy.int, args=(permute, 1, 0) + ) + builder.output([transpose]) + original = builder.get_graph_module() + gm_before = copy.deepcopy(original) + + p = FuseCascadedTransposeOrPermuteOps() + result = cast(PassResult, p(original)) + self.assertTrue(result.modified) + gm = result.graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0) + validate_numerics( + gm_before, + gm, + [torch.randn(3, 1, 3, 1, 4)], + "FuseCascadedTransposeOrPermuteOps", + ) + + def test_cascaded_permutes_multiple_users(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) + permute1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1]) + ) + # permute2 reverses permute1 => identity + permute2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, args=(permute1, [0, 3, 1, 2]) + ) + # permute3: different 
class FuseCascadedViewOpsTest(unittest.TestCase):
    """Checks that FuseCascadedViewOps collapses chains of view_copy ops."""

    def test_view_fusion(self) -> None:
        """A linear chain of three views is fused into a single view."""
        builder = GraphBuilder()
        node = builder.placeholder("x", torch.randn(8, 5, 3))
        # Each view consumes the previous one, forming a straight chain.
        for shape in ([1, 8, 15], [1, 1, 120], [120]):
            node = builder.call_operator(
                op=exir_ops.edge.aten.view_copy.default, args=(node, shape)
            )
        builder.output([node])
        graph_module = builder.get_graph_module()
        reference = copy.deepcopy(graph_module)

        result = cast(PassResult, FuseCascadedViewOps()(graph_module))

        self.assertTrue(result.modified)
        fused = result.graph_module
        self.assertEqual(count_node(fused, exir_ops.edge.aten.view_copy.default), 1)
        validate_numerics(
            reference,
            fused,
            [torch.randn(8, 5, 3)],
            "FuseCascadedViewOps",
        )

    def test_view_fusion_branched(self) -> None:
        """A view feeding two sibling views leaves exactly two views behind."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(8, 5, 3))
        shared = builder.call_operator(
            op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15])
        )
        # Both branches read the same intermediate view.
        branches = [
            builder.call_operator(
                op=exir_ops.edge.aten.view_copy.default, args=(shared, shape)
            )
            for shape in ([1, 1, 120], [120, 1, 1])
        ]
        builder.output(branches)
        graph_module = builder.get_graph_module()
        reference = copy.deepcopy(graph_module)

        result = cast(PassResult, FuseCascadedViewOps()(graph_module))

        self.assertTrue(result.modified)
        fused = result.graph_module
        self.assertEqual(count_node(fused, exir_ops.edge.aten.view_copy.default), 2)
        validate_numerics(
            reference,
            fused,
            [torch.randn(8, 5, 3)],
            "FuseCascadedViewOps",
        )
class PostponePermuteBelowSqueezeViewTest(unittest.TestCase):
    """Tests for PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView on view/permute chains."""

    def _build_chain(
        self,
        x_data: torch.Tensor,
        view1_shape: list[int],
        permute1_dims: list[int],
        view2_shape: list[int],
        permute2_dims: list[int],
    ) -> torch.fx.GraphModule:
        """Build placeholder -> view -> permute -> view -> permute -> output."""
        view_op = exir_ops.edge.aten.view_copy.default
        permute_op = exir_ops.edge.aten.permute_copy.default
        builder = GraphBuilder()
        node = builder.placeholder("x", x_data)
        steps = (
            (view_op, view1_shape),
            (permute_op, permute1_dims),
            (view_op, view2_shape),
            (permute_op, permute2_dims),
        )
        for op, arg in steps:
            node = builder.call_operator(op=op, args=(node, arg))
        builder.output([node])
        return builder.get_graph_module()

    def _assert_views_precede_permutes(self, gm: torch.fx.GraphModule) -> None:
        """Assert every view_copy comes before every permute_copy in graph order."""
        targets = get_compute_nodes(gm)
        view_positions = [
            pos
            for pos, target in enumerate(targets)
            if target == exir_ops.edge.aten.view_copy.default
        ]
        permute_positions = [
            pos
            for pos, target in enumerate(targets)
            if target == exir_ops.edge.aten.permute_copy.default
        ]
        self.assertTrue(
            all(v < p for v in view_positions for p in permute_positions)
        )

    def test_permute3_view4_chains(self) -> None:
        """view→permute→view→permute reordered to view→view→permute→permute."""
        x_data = torch.randn(3, 1, 768)
        original = self._build_chain(
            x_data, [3, 12, 64], [1, 0, 2], [1, 12, 3, 64], [0, 1, 3, 2]
        )
        gm_before = copy.deepcopy(original)

        pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()
        result = cast(PassResult, pass_instance.call(original))
        self.assertTrue(result.modified)
        gm = result.graph_module
        gm.graph.eliminate_dead_code()

        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2)
        # Verify order: views before permutes
        self._assert_views_precede_permutes(gm)

        validate_numerics(
            gm_before,
            gm,
            [x_data],
            "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView",
        )

    def test_permute4_view3_chains(self) -> None:
        """4d→permute→view→3d→permute reordered to view→view→permute→permute."""
        x_data = torch.randn(3, 1, 768)
        original = self._build_chain(
            x_data, [1, 3, 12, 64], [3, 1, 0, 2], [64, 3, 12], [2, 1, 0]
        )
        gm_before = copy.deepcopy(original)

        pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()
        result = cast(PassResult, pass_instance.call(original))
        self.assertTrue(result.modified)
        gm = result.graph_module

        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 2)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 2)
        self._assert_views_precede_permutes(gm)

        validate_numerics(
            gm_before,
            gm,
            [x_data],
            "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView",
        )

    def test_negative_not_squeeze_like(self) -> None:
        """View that reshapes (not just squeeze/unsqueeze) should NOT be reordered."""
        x_data = torch.randn(3, 1, 768)
        original = self._build_chain(
            x_data, [1, 3, 12, 64], [3, 1, 0, 2], [64, 6, 6], [2, 1, 0]
        )

        pass_instance = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()
        result = cast(PassResult, pass_instance.call(original))
        self.assertFalse(result.modified)

        untouched = result.graph_module
        self.assertEqual(
            count_node(untouched, exir_ops.edge.aten.view_copy.default), 2
        )
        self.assertEqual(
            count_node(untouched, exir_ops.edge.aten.permute_copy.default),
            2,
        )
        # Order unchanged: view, permute, view, permute
        targets = get_compute_nodes(untouched)
        self.assertEqual(targets[0], exir_ops.edge.aten.view_copy.default)
        self.assertEqual(targets[1], exir_ops.edge.aten.permute_copy.default)
ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_transpose_with_view_int(self) -> None: + x = torch.randint(low=0, high=100, size=(2, 1, 5), dtype=torch.int64) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.transpose_copy.int, + args=(x, 1, 0), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.transpose_copy.int), 0) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_5d(self) -> None: + x = torch.randn(3, 1, 3, 1, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 4, 1, 3]), + ) + gm_before = copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + ) + + def test_replace_nop_permute_3d(self) -> None: + x = torch.randn(1, 3, 4) + gm = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [1, 2, 0]), + ) + gm_before = 
copy.deepcopy(gm) + + p = ReplaceNopTransposeOrPermuteWithViewPass() + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after = result.graph_module + self.assertEqual( + count_node(gm_after, exir_ops.edge.aten.permute_copy.default), 0 + ) + self.assertEqual(count_node(gm_after, exir_ops.edge.aten.view_copy.default), 1) + validate_numerics( + gm_before, gm_after, [x], "ReplaceNopTransposeOrPermuteWithViewPass" + )