
Commit 6e540da

mcremon-metameta-codesync[bot] authored and committed
Move permute optimization passes to shared transforms location (#19002)
Summary:
Pull Request resolved: #19002

Move 6 permute optimization passes and their shared infrastructure from executorch/backends/cadence/aot/ to executorch/backends/transforms/ so they can be shared between the Cadence and Arm backends without a cross-backend dependency.

New files:
- permute_pass_utils.py: base classes (HierarchicalInplacePassInterface, RemoveOrReplacePassInterface, FuseOpPairsAcrossBranchesPass) and utilities (get_arg, set_arg, get_transposed_dims, get_permuted_dims, get_shape, get_edge_overload_packet)
- fuse_cascaded_transpose_or_permute_ops.py
- fuse_cascaded_view_ops.py
- fuse_transpose_or_permute_op_pairs_pass.py
- remove_permutes_around_elementwise_ops.py
- postpone_permute_below_squeeze_view.py
- replace_nop_transpose_or_permute_with_view.py

The shared versions omit the register_cadence_pass decorators and the Cadence-specific ops from the default op sets. The Cadence files subclass the shared versions and re-add the decorators and ops.

Added OSS tests (test_permute_optimization_passes.py) for the 4 passes that can be imported without quantized op registration: FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, and ReplaceNopTransposeOrPermuteWithViewPass. These run in GitHub CI via pytest and are discovered automatically through pytest.ini testpaths.

Differential Revision: D101459577

Reviewed By: ethansfng
1 parent 8e5ec80 commit 6e540da

15 files changed

Lines changed: 1630 additions & 942 deletions
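
The Cadence files keep their public class names by subclassing the shared implementations and re-applying the Cadence-only decorator, as the summary above describes. A minimal sketch of that pattern, assuming register_cadence_pass and CadencePassAttribute live in executorch/backends/cadence/aot/pass_utils (as the BUCK deps below suggest):

```python
# Sketch of the subclass-and-redecorate pattern from the commit summary.
# Assumption: register_cadence_pass / CadencePassAttribute come from
# executorch/backends/cadence/aot/pass_utils, per the BUCK deps below.
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)
from executorch.backends.transforms.fuse_cascaded_view_ops import (
    FuseCascadedViewOps as _SharedFuseCascadedViewOps,
)


# The shared pass carries the graph rewrite; the Cadence subclass only
# re-adds the registration metadata stripped from the shared copy.
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedViewOps(_SharedFuseCascadedViewOps):
    pass
```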

backends/cadence/aot/BUCK

Lines changed: 7 additions & 0 deletions
@@ -267,6 +267,10 @@ fbcode_target(_kind = runtime.python_library,
         "//caffe2:torch",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:utils",
+        "//executorch/backends/transforms:fuse_cascaded_transpose_or_permute_ops",
+        "//executorch/backends/transforms:fuse_cascaded_view_ops",
+        "//executorch/backends/transforms:fuse_transpose_or_permute_op_pairs_pass",
+        "//executorch/backends/transforms:permute_pass_utils",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
@@ -304,6 +308,7 @@ fbcode_target(_kind = runtime.python_library,
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:simplify_ops",
         "//executorch/backends/transforms:remove_clone_ops",
+        "//executorch/backends/transforms:remove_permutes_around_elementwise_ops",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/dialects/edge:lib",
@@ -322,6 +327,7 @@ fbcode_target(_kind = runtime.python_library,
         "//executorch/backends/cadence/aot:compiler_utils",
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:utils",
+        "//executorch/backends/transforms:postpone_permute_below_squeeze_view",
         "//executorch/exir:pass_base",
         "//executorch/exir:tensor",
         "//executorch/exir/dialects:lib",
@@ -343,6 +349,7 @@ fbcode_target(_kind = runtime.python_library,
         "//executorch/backends/cadence/aot:pass_utils",
         "//executorch/backends/cadence/aot:remove_ops",
         "//executorch/backends/cadence/aot:utils",
+        "//executorch/backends/transforms:replace_nop_transpose_or_permute_with_view",
         "//executorch/backends/transforms:replace_scalar_with_tensor",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",

backends/cadence/aot/fuse_ops.py

Lines changed: 18 additions & 276 deletions
@@ -41,6 +41,18 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.passes.cse_pass import CSEPass
 from torch.nn.utils.fusion import fuse_conv_bn_weights
+from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
+    FuseCascadedTransposeOrPermuteOps as _SharedFuseCascadedTransposeOrPermuteOps,
+)
+from executorch.backends.transforms.fuse_cascaded_view_ops import (
+    FuseCascadedViewOps as _SharedFuseCascadedViewOps,
+)
+from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import (
+    FuseTransposeOrPermuteOpPairsPass as _SharedFuseTransposeOrPermuteOpPairsPass,
+)
+from executorch.backends.transforms.permute_pass_utils import (
+    FuseOpPairsAcrossBranchesPass,
+)
 
 
 def get_tensor_arg(node: torch.fx.Node, arg_name: str) -> torch.Tensor:
@@ -578,207 +590,14 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
-    """
-    Fuse a chain of transpose and permute ops into a single permute or a no-op.
-    Handles branches and chains permutes.
-    """
-
-    transpose_or_permute_target = {
-        exir_ops.edge.aten.transpose_copy.int,
-        exir_ops.edge.aten.permute_copy.default,
-    }
-
-    @property
-    def targets(self) -> list[EdgeOpOverload]:
-        return list(self.transpose_or_permute_target)
-
-    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
-        # Fuse with the parent node if it's also a permute or a transpose. Since the
-        # pass interface traverses all ops in order the pass will properly fuse a chain
-        # of permutes.
-        parent_node = get_arg(node, "input", torch.fx.Node)
-        if parent_node.target not in self.transpose_or_permute_target:
-            return False
-        input_of_parent = get_arg(parent_node, "input", torch.fx.Node)
-
-        # Compute combined effect of permutes.
-        dims = list(range(node.meta["val"].ndim))
-
-        if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
-            dims = get_transposed_dims(parent_node, dims)
-        else:
-            dims = get_permuted_dims(parent_node, dims)
-
-        if node.target == exir_ops.edge.aten.transpose_copy.int:
-            dims = get_transposed_dims(node, dims)
-        else:
-            dims = get_permuted_dims(node, dims)
-
-        # If combined effect is identity replace the node with input.
-        if dims == sorted(dims):
-            node.replace_all_uses_with(input_of_parent)
-        else:
-            with node.graph.inserting_before(node):
-                new_permute = node.graph.call_function(
-                    exir_ops.edge.aten.permute_copy.default,
-                    args=(input_of_parent, dims),
-                )
-                new_permute.meta = node.meta
-                node.replace_all_uses_with(new_permute)
-
-        return True
+class FuseCascadedTransposeOrPermuteOps(_SharedFuseCascadedTransposeOrPermuteOps):
+    pass
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseCascadedViewOps(RemoveOrReplacePassInterface):
-    """
-    Fuse a cascaded chain of view ops
-    """
+class FuseCascadedViewOps(_SharedFuseCascadedViewOps):
+    pass
 
-    @property
-    def targets(self) -> list[EdgeOpOverload]:
-        return [exir_ops.edge.aten.view_copy.default]
-
-    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
-        # Check if the input to this view node is also a view node
-        input_view = node.args[0]
-        if not isinstance(input_view, torch.fx.Node):
-            return False
-
-        if (
-            input_view.op != "call_function"
-            or input_view.target != exir_ops.edge.aten.view_copy.default
-        ):
-            return False
-
-        # Replace the input of this view node with the input of the cascaded view
-        # This effectively "skips" the intermediate view node
-        node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0]))
-        return True
-
-
-class FuseOpPairsAcrossBranchesPass(ExportPass):
-    """
-    Base class for passes that fuse op pairs across branches.
-    Provides common functionality for finding and fusing producer-consumer chains.
-    """
-
-    def check_ok_to_fuse(
-        self,
-        producer: torch.fx.Node,
-        consumers: list[torch.fx.Node],
-    ) -> bool:
-        # Always ok to replace / remove.
-        return True
-
-    def can_fuse_for_chain(
-        self,
-        producer: torch.fx.Node,
-        consumer: torch.fx.Node,
-        consumer_op_packets: set[EdgeOpOverloadPacket],
-    ) -> bool:
-        """
-        Returns true if producer and consumer can be fused for a single chain
-        (-> producer -> ops -> consumer ->) to (-> ops -> fused_op)
-        """
-        if (
-            isinstance(consumer.target, EdgeOpOverload)
-            and get_edge_overload_packet(consumer.target) in consumer_op_packets
-        ):
-            return True
-        return False
-
-    def get_fuse_candidates(
-        self,
-        producer: torch.fx.Node,
-        consumer_op_packets: set[EdgeOpOverloadPacket],
-        bypass_ops: set[EdgeOpOverload],
-    ) -> list[torch.fx.Node]:
-        # Start by iterating over all the users of this node, and check
-        # if they are have their target in consumer_op_packets.
-        users = deque(producer.users.keys())
-        # This holds the list of the user ops that directly (or transitively
-        # via view/slice) consume this producer_op_packets, and hence can be removed.
-        removal_candidates = []
-        while users:
-            user = users.popleft()
-
-            # If the user is a bypass op, we bypass it, and examine
-            # its users instead for consumer_op_packets.
-            if user.target in bypass_ops:
-                users.extend(list(user.users.keys()))
-            elif self.can_fuse_for_chain(producer, user, consumer_op_packets):
-                removal_candidates.append(user)
-            else:
-                removal_candidates.clear()
-                break
-        return removal_candidates
-
-    def find_and_fuse(
-        self,
-        graph_module: torch.fx.GraphModule,
-        producer_op_packets: set[EdgeOpOverloadPacket],
-        consumer_op_packets: set[EdgeOpOverloadPacket],
-        bypass_ops: set[EdgeOpOverload],
-    ) -> bool:
-        """
-        Find and fuse producer-consumer op pairs.
-
-        Returns True if any fusion was performed, False otherwise.
-        """
-        modified = False
-        for node in graph_module.graph.nodes:
-            # We are only interested in ops that have overload target in
-            # producer_op.
-            if not (
-                isinstance(node.target, EdgeOpOverload)
-                and get_edge_overload_packet(node.target) in producer_op_packets
-            ):
-                continue
-
-            removal_candidates = self.get_fuse_candidates(
-                node, consumer_op_packets, bypass_ops
-            )
-
-            if len(removal_candidates) == 0:
-                # No candidates found.
-                continue
-
-            if not self.check_ok_to_fuse(node, removal_candidates):
-                # Not ok to remove quant-dequant pairs or replace with requantize.
-                continue
-
-            self.fuse(node, removal_candidates, graph_module)
-            modified = True
-
-        if modified:
-            graph_module.recompile()
-
-        return modified
-
-    def get_fused_node(
-        self,
-        producer: torch.fx.Node,
-        consumer: torch.fx.Node,
-        graph_module: torch.fx.GraphModule,
-    ) -> torch.fx.Node:
-        return consumer
-
-    def fuse(
-        self,
-        node: torch.fx.Node,
-        removal_candidates: list[torch.fx.Node],
-        graph_module: torch.fx.GraphModule,
-    ) -> None:
-        # Replace all the uses of the producer op with it's input.
-        node.replace_all_uses_with(cast(torch.fx.Node, node.args[0]))
-        graph_module.graph.erase_node(node)
-
-        # Iterate over all the removal candidates (quantize op users) and generate replacements.
-        for rnode in removal_candidates:
-            rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module))
-            graph_module.graph.erase_node(rnode)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -1123,90 +942,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
-    """
-    Fuse transpose or permute op pairs to a single view op.
-    (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
-    This happens when op2(op1) == identity, modulo unitary dimensions.
-    'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
-    so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
-    """
-
-    # A list of ops that can be bypassed when looking for a
-    # dequantize->quantize chain
-    bypass_ops: set[EdgeOpOverload] = {
+class FuseTransposeOrPermuteOpPairsPass(_SharedFuseTransposeOrPermuteOpPairsPass):
+    bypass_ops: set[EdgeOpOverload] = _SharedFuseTransposeOrPermuteOpPairsPass.bypass_ops | {
         exir_ops.edge.cadence.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
         exir_ops.edge.cadence.dequantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         exir_ops.edge.cadence.quantized_relu.per_tensor,
     }
 
-    def can_fuse_for_chain(
-        self,
-        producer: torch.fx.Node,
-        consumer: torch.fx.Node,
-        consumer_op_packets: set[EdgeOpOverloadPacket],
-    ) -> bool:
-        if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
-            return False
-
-        # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
-        producer_input = cast(torch.fx.Node, producer.args[0])
-        if "val" not in producer_input.meta:
-            return False
-        input_shape = producer_input.meta["val"].shape
-        ident_dims = list(range(len(input_shape)))
-        # this mapping helps to handle both transpose and permutations
-        f: dict[Any, Callable] = {
-            exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
-            exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
-        }
-        in_dims = f[producer.target](producer, ident_dims)
-        out_dims = f[consumer.target](consumer, in_dims)
-        # Filtering out unitary dimensions
-        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
-        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
-        return non_unit_out_dims == non_unit_ident_dims
-
-    def get_fused_node(
-        self,
-        producer: torch.fx.Node,
-        consumer: torch.fx.Node,
-        graph_module: torch.fx.GraphModule,
-    ) -> torch.fx.Node:
-        # This step is important because of how we can fuse transpositions that are not perfectly
-        # reverse one of another but will be fused if there are unitary dimensions.
-        # The fused operation must have the same output shape as the consumer.
-        output_shape = consumer.meta["val"].shape
-        with graph_module.graph.inserting_after(consumer):
-            view = graph_module.graph.call_function(
-                exir_ops.edge.aten.view_copy.default,
-                (consumer.args[0], output_shape),
-                {},
-            )
-        return view
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        # Remove any transpose/permutation op pair that cancel each other.
-        modified = self.find_and_fuse(
-            graph_module,
-            producer_op_packets={
-                exir_ops.edge.aten.transpose_copy,
-                exir_ops.edge.aten.permute_copy,
-            },
-            consumer_op_packets={
-                exir_ops.edge.aten.transpose_copy,
-                exir_ops.edge.aten.permute_copy,
-            },
-            bypass_ops=self.bypass_ops,
-        )
-        if modified:
-            return super().call(graph_module)
-        return PassResult(graph_module, False)
-
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class FuseFullThenReshapePass(RemoveOrReplacePassInterface):
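
The "identity modulo unitary dimensions" condition enforced by the removed can_fuse_for_chain can be checked with plain lists. A self-contained sketch of the docstring's [1, 5, 30] example, using a hypothetical transpose helper standing in for get_transposed_dims (which operates on fx nodes):

```python
# Worked example: on shape [1, 5, 30], transpose(1, 2) followed by
# transpose(0, 2) composes to a pseudo identity, so the pair can be
# replaced by a single view with the consumer's output shape.
def transpose(dims: list[int], i: int, j: int) -> list[int]:
    # Hypothetical stand-in for get_transposed_dims: swap two axes.
    out = list(dims)
    out[i], out[j] = out[j], out[i]
    return out

input_shape = (1, 5, 30)
ident = list(range(len(input_shape)))  # [0, 1, 2]
dims = transpose(ident, 1, 2)          # [0, 2, 1]
dims = transpose(dims, 0, 2)           # [1, 2, 0] -- not a strict identity

def non_unit(order: list[int]) -> list[int]:
    # Drop size-1 axes; they do not affect the contiguous memory layout.
    return [d for d in order if input_shape[d] != 1]

assert non_unit(dims) == non_unit(ident)  # both [1, 2] -> fusible to a view
```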
