|
41 | 41 | from executorch.exir.pass_base import ExportPass, PassResult |
42 | 42 | from executorch.exir.passes.cse_pass import CSEPass |
43 | 43 | from torch.nn.utils.fusion import fuse_conv_bn_weights |
| 44 | +from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import ( |
| 45 | + FuseCascadedTransposeOrPermuteOps as _SharedFuseCascadedTransposeOrPermuteOps, |
| 46 | +) |
| 47 | +from executorch.backends.transforms.fuse_cascaded_view_ops import ( |
| 48 | + FuseCascadedViewOps as _SharedFuseCascadedViewOps, |
| 49 | +) |
| 50 | +from executorch.backends.transforms.fuse_transpose_or_permute_op_pairs_pass import ( |
| 51 | + FuseTransposeOrPermuteOpPairsPass as _SharedFuseTransposeOrPermuteOpPairsPass, |
| 52 | +) |
| 53 | +from executorch.backends.transforms.permute_pass_utils import ( |
| 54 | + FuseOpPairsAcrossBranchesPass, |
| 55 | +) |
44 | 56 |
|
45 | 57 |
|
46 | 58 | def get_tensor_arg(node: torch.fx.Node, arg_name: str) -> torch.Tensor: |
@@ -578,207 +590,14 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
578 | 590 |
|
579 | 591 |
|
580 | 592 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
581 | | -class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): |
582 | | - """ |
583 | | - Fuse a chain of transpose and permute ops into a single permute or a no-op. |
584 | | - Handles branches and chains of permutes. |
585 | | - """ |
586 | | - |
587 | | - transpose_or_permute_target = { |
588 | | - exir_ops.edge.aten.transpose_copy.int, |
589 | | - exir_ops.edge.aten.permute_copy.default, |
590 | | - } |
591 | | - |
592 | | - @property |
593 | | - def targets(self) -> list[EdgeOpOverload]: |
594 | | - return list(self.transpose_or_permute_target) |
595 | | - |
596 | | - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
597 | | - # Fuse with the parent node if it's also a permute or a transpose. Since the |
598 | | - # pass interface traverses all ops in order, the pass will properly fuse a |
599 | | - # chain of permutes. |
600 | | - parent_node = get_arg(node, "input", torch.fx.Node) |
601 | | - if parent_node.target not in self.transpose_or_permute_target: |
602 | | - return False |
603 | | - input_of_parent = get_arg(parent_node, "input", torch.fx.Node) |
604 | | - |
605 | | - # Compute combined effect of permutes. |
606 | | - dims = list(range(node.meta["val"].ndim)) |
607 | | - |
608 | | - if parent_node.target == exir_ops.edge.aten.transpose_copy.int: |
609 | | - dims = get_transposed_dims(parent_node, dims) |
610 | | - else: |
611 | | - dims = get_permuted_dims(parent_node, dims) |
612 | | - |
613 | | - if node.target == exir_ops.edge.aten.transpose_copy.int: |
614 | | - dims = get_transposed_dims(node, dims) |
615 | | - else: |
616 | | - dims = get_permuted_dims(node, dims) |
617 | | - |
618 | | - # If combined effect is identity replace the node with input. |
619 | | - if dims == sorted(dims): |
620 | | - node.replace_all_uses_with(input_of_parent) |
621 | | - else: |
622 | | - with node.graph.inserting_before(node): |
623 | | - new_permute = node.graph.call_function( |
624 | | - exir_ops.edge.aten.permute_copy.default, |
625 | | - args=(input_of_parent, dims), |
626 | | - ) |
627 | | - new_permute.meta = node.meta |
628 | | - node.replace_all_uses_with(new_permute) |
629 | | - |
630 | | - return True |
| 593 | +class FuseCascadedTransposeOrPermuteOps(_SharedFuseCascadedTransposeOrPermuteOps): |
| 594 | + pass |
631 | 595 |
|
632 | 596 |
|
633 | 597 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
634 | | -class FuseCascadedViewOps(RemoveOrReplacePassInterface): |
635 | | - """ |
636 | | - Fuse a cascaded chain of view ops |
637 | | - """ |
| 598 | +class FuseCascadedViewOps(_SharedFuseCascadedViewOps): |
| 599 | + pass |
638 | 600 |
|
639 | | - @property |
640 | | - def targets(self) -> list[EdgeOpOverload]: |
641 | | - return [exir_ops.edge.aten.view_copy.default] |
642 | | - |
643 | | - def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
644 | | - # Check if the input to this view node is also a view node |
645 | | - input_view = node.args[0] |
646 | | - if not isinstance(input_view, torch.fx.Node): |
647 | | - return False |
648 | | - |
649 | | - if ( |
650 | | - input_view.op != "call_function" |
651 | | - or input_view.target != exir_ops.edge.aten.view_copy.default |
652 | | - ): |
653 | | - return False |
654 | | - |
655 | | - # Replace the input of this view node with the input of the cascaded view |
656 | | - # This effectively "skips" the intermediate view node |
657 | | - node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) |
658 | | - return True |
659 | | - |
660 | | - |
661 | | -class FuseOpPairsAcrossBranchesPass(ExportPass): |
662 | | - """ |
663 | | - Base class for passes that fuse op pairs across branches. |
664 | | - Provides common functionality for finding and fusing producer-consumer chains. |
665 | | - """ |
666 | | - |
667 | | - def check_ok_to_fuse( |
668 | | - self, |
669 | | - producer: torch.fx.Node, |
670 | | - consumers: list[torch.fx.Node], |
671 | | - ) -> bool: |
672 | | - # Always ok to replace / remove. |
673 | | - return True |
674 | | - |
675 | | - def can_fuse_for_chain( |
676 | | - self, |
677 | | - producer: torch.fx.Node, |
678 | | - consumer: torch.fx.Node, |
679 | | - consumer_op_packets: set[EdgeOpOverloadPacket], |
680 | | - ) -> bool: |
681 | | - """ |
682 | | - Returns true if producer and consumer can be fused for a single chain |
683 | | - (-> producer -> ops -> consumer ->) to (-> ops -> fused_op) |
684 | | - """ |
685 | | - if ( |
686 | | - isinstance(consumer.target, EdgeOpOverload) |
687 | | - and get_edge_overload_packet(consumer.target) in consumer_op_packets |
688 | | - ): |
689 | | - return True |
690 | | - return False |
691 | | - |
692 | | - def get_fuse_candidates( |
693 | | - self, |
694 | | - producer: torch.fx.Node, |
695 | | - consumer_op_packets: set[EdgeOpOverloadPacket], |
696 | | - bypass_ops: set[EdgeOpOverload], |
697 | | - ) -> list[torch.fx.Node]: |
698 | | - # Start by iterating over all the users of this node, and check |
699 | | - # if they have their target in consumer_op_packets. |
700 | | - users = deque(producer.users.keys()) |
701 | | - # This holds the list of the user ops that directly (or transitively |
702 | | - # via view/slice) consume this producer_op_packets, and hence can be removed. |
703 | | - removal_candidates = [] |
704 | | - while users: |
705 | | - user = users.popleft() |
706 | | - |
707 | | - # If the user is a bypass op, we bypass it, and examine |
708 | | - # its users instead for consumer_op_packets. |
709 | | - if user.target in bypass_ops: |
710 | | - users.extend(list(user.users.keys())) |
711 | | - elif self.can_fuse_for_chain(producer, user, consumer_op_packets): |
712 | | - removal_candidates.append(user) |
713 | | - else: |
714 | | - removal_candidates.clear() |
715 | | - break |
716 | | - return removal_candidates |
717 | | - |
718 | | - def find_and_fuse( |
719 | | - self, |
720 | | - graph_module: torch.fx.GraphModule, |
721 | | - producer_op_packets: set[EdgeOpOverloadPacket], |
722 | | - consumer_op_packets: set[EdgeOpOverloadPacket], |
723 | | - bypass_ops: set[EdgeOpOverload], |
724 | | - ) -> bool: |
725 | | - """ |
726 | | - Find and fuse producer-consumer op pairs. |
727 | | -
|
728 | | - Returns True if any fusion was performed, False otherwise. |
729 | | - """ |
730 | | - modified = False |
731 | | - for node in graph_module.graph.nodes: |
732 | | - # We are only interested in ops that have overload target in |
733 | | - # producer_op. |
734 | | - if not ( |
735 | | - isinstance(node.target, EdgeOpOverload) |
736 | | - and get_edge_overload_packet(node.target) in producer_op_packets |
737 | | - ): |
738 | | - continue |
739 | | - |
740 | | - removal_candidates = self.get_fuse_candidates( |
741 | | - node, consumer_op_packets, bypass_ops |
742 | | - ) |
743 | | - |
744 | | - if len(removal_candidates) == 0: |
745 | | - # No candidates found. |
746 | | - continue |
747 | | - |
748 | | - if not self.check_ok_to_fuse(node, removal_candidates): |
749 | | - # Not ok to remove quant-dequant pairs or replace with requantize. |
750 | | - continue |
751 | | - |
752 | | - self.fuse(node, removal_candidates, graph_module) |
753 | | - modified = True |
754 | | - |
755 | | - if modified: |
756 | | - graph_module.recompile() |
757 | | - |
758 | | - return modified |
759 | | - |
760 | | - def get_fused_node( |
761 | | - self, |
762 | | - producer: torch.fx.Node, |
763 | | - consumer: torch.fx.Node, |
764 | | - graph_module: torch.fx.GraphModule, |
765 | | - ) -> torch.fx.Node: |
766 | | - return consumer |
767 | | - |
768 | | - def fuse( |
769 | | - self, |
770 | | - node: torch.fx.Node, |
771 | | - removal_candidates: list[torch.fx.Node], |
772 | | - graph_module: torch.fx.GraphModule, |
773 | | - ) -> None: |
774 | | - # Replace all the uses of the producer op with its input. |
775 | | - node.replace_all_uses_with(cast(torch.fx.Node, node.args[0])) |
776 | | - graph_module.graph.erase_node(node) |
777 | | - |
778 | | - # Iterate over all the removal candidates (quantize op users) and generate replacements. |
779 | | - for rnode in removal_candidates: |
780 | | - rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module)) |
781 | | - graph_module.graph.erase_node(rnode) |
782 | 601 |
|
783 | 602 |
|
784 | 603 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
@@ -1123,90 +942,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: |
1123 | 942 |
|
1124 | 943 |
|
1125 | 944 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
1126 | | -class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass): |
1127 | | - """ |
1128 | | - Fuse transpose or permute op pairs to a single view op. |
1129 | | - (transpose or permutation) -> (quant or dequant) -> (transpose or permutation) |
1130 | | - This happens when op2(op1) == identity, modulo unitary dimensions. |
1131 | | - 'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30] |
1132 | | - so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused. |
1133 | | - """ |
1134 | | - |
1135 | | - # A list of ops that can be bypassed when looking for a |
1136 | | - # dequantize->quantize chain |
1137 | | - bypass_ops: set[EdgeOpOverload] = { |
| 945 | +class FuseTransposeOrPermuteOpPairsPass(_SharedFuseTransposeOrPermuteOpPairsPass): |
| 946 | + bypass_ops: set[EdgeOpOverload] = _SharedFuseTransposeOrPermuteOpPairsPass.bypass_ops | { |
1138 | 947 | exir_ops.edge.cadence.quantize_per_tensor.default, |
1139 | | - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
1140 | | - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, |
1141 | 948 | exir_ops.edge.cadence.dequantize_per_tensor.default, |
1142 | | - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
1143 | | - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, |
1144 | 949 | exir_ops.edge.cadence.quantized_relu.per_tensor, |
1145 | 950 | } |
1146 | 951 |
|
1147 | | - def can_fuse_for_chain( |
1148 | | - self, |
1149 | | - producer: torch.fx.Node, |
1150 | | - consumer: torch.fx.Node, |
1151 | | - consumer_op_packets: set[EdgeOpOverloadPacket], |
1152 | | - ) -> bool: |
1153 | | - if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets): |
1154 | | - return False |
1155 | | - |
1156 | | - # Check that permute2(permute1(identity)) == identity, modulo unitary dimensions. |
1157 | | - producer_input = cast(torch.fx.Node, producer.args[0]) |
1158 | | - if "val" not in producer_input.meta: |
1159 | | - return False |
1160 | | - input_shape = producer_input.meta["val"].shape |
1161 | | - ident_dims = list(range(len(input_shape))) |
1162 | | - # this mapping helps to handle both transpose and permutations |
1163 | | - f: dict[Any, Callable] = { |
1164 | | - exir_ops.edge.aten.transpose_copy.int: get_transposed_dims, |
1165 | | - exir_ops.edge.aten.permute_copy.default: get_permuted_dims, |
1166 | | - } |
1167 | | - in_dims = f[producer.target](producer, ident_dims) |
1168 | | - out_dims = f[consumer.target](consumer, in_dims) |
1169 | | - # Filtering out unitary dimensions |
1170 | | - non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1] |
1171 | | - non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1] |
1172 | | - return non_unit_out_dims == non_unit_ident_dims |
1173 | | - |
1174 | | - def get_fused_node( |
1175 | | - self, |
1176 | | - producer: torch.fx.Node, |
1177 | | - consumer: torch.fx.Node, |
1178 | | - graph_module: torch.fx.GraphModule, |
1179 | | - ) -> torch.fx.Node: |
1180 | | - # This step is important because we can fuse transpositions that are not exact |
1181 | | - # inverses of one another when unitary dimensions are involved. |
1182 | | - # The fused operation must have the same output shape as the consumer. |
1183 | | - output_shape = consumer.meta["val"].shape |
1184 | | - with graph_module.graph.inserting_after(consumer): |
1185 | | - view = graph_module.graph.call_function( |
1186 | | - exir_ops.edge.aten.view_copy.default, |
1187 | | - (consumer.args[0], output_shape), |
1188 | | - {}, |
1189 | | - ) |
1190 | | - return view |
1191 | | - |
1192 | | - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
1193 | | - # Remove any transpose/permutation op pair that cancel each other. |
1194 | | - modified = self.find_and_fuse( |
1195 | | - graph_module, |
1196 | | - producer_op_packets={ |
1197 | | - exir_ops.edge.aten.transpose_copy, |
1198 | | - exir_ops.edge.aten.permute_copy, |
1199 | | - }, |
1200 | | - consumer_op_packets={ |
1201 | | - exir_ops.edge.aten.transpose_copy, |
1202 | | - exir_ops.edge.aten.permute_copy, |
1203 | | - }, |
1204 | | - bypass_ops=self.bypass_ops, |
1205 | | - ) |
1206 | | - if modified: |
1207 | | - return super().call(graph_module) |
1208 | | - return PassResult(graph_module, False) |
1209 | | - |
1210 | 952 |
|
1211 | 953 | @register_cadence_pass(CadencePassAttribute(opt_level=1)) |
1212 | 954 | class FuseFullThenReshapePass(RemoveOrReplacePassInterface): |
|