1414
1515import torch
1616import torch .fx
17+
18+ from executorch .backends .cadence .aot .fuse_ops import FuseTransposeOrPermuteOpPairsPass
1719from executorch .backends .cadence .aot .pass_utils import (
1820 CadencePassAttribute ,
1921 get_arg ,
2022 register_cadence_pass ,
2123 RemoveOrReplacePassInterface ,
2224 set_arg ,
2325)
24-
2526from executorch .backends .cadence .aot .simplify_ops import SimplifySliceOpPass
2627from executorch .backends .cadence .aot .utils import get_edge_overload_packet
2728from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
3334from torch .fx .node import Node
3435
3536
36- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
37+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
3738class RemoveCloneOpsTransformImported (ExportPass ):
3839 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
3940 finalize_passes : List [PassType ] = [
@@ -44,7 +45,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4445 return result
4546
4647
47- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
48+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
4849class RemoveDetachCopyPass (RemoveOrReplacePassInterface ):
4950 @property
5051 def targets (self ) -> list [EdgeOpOverload ]:
@@ -66,7 +67,7 @@ class RemoveRedundantOps:
6667 ]
6768
6869
69- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
70+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
7071class RemoveZeroSizedCatArgsPass (RemoveOrReplacePassInterface ):
7172 @property
7273 def targets (self ) -> list [EdgeOpOverload ]:
@@ -120,11 +121,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
120121 return False
121122
122123
123- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
124+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
124125class RemoveNopExpandOpPass (RemoveOrReplacePassInterface ):
125126 """
126127 For an expand op, if the operator shape matches the expand shape, then the
127- expand is a nop.
128+ expand is a nop. This is an optimization that removes unnecessary ops.
128129 """
129130
130131 @property
@@ -143,9 +144,9 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
143144 return False
144145
145146
146- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
147+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
147148class RemoveToOpsPass (RemoveOrReplacePassInterface ):
148- # aten.to.* as of now are all nops
149+ # aten.to.* ops are no-ops in inference - this is an optimization
149150 @property
150151 def targets (self ) -> list [EdgeOpOverload ]:
151152 return [
@@ -264,11 +265,11 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
264265 return True
265266
266267
267- @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
268+ @register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
268269class RemoveAliasCopyOpPass (RemoveOrReplacePassInterface ):
269270 """
270-
271271 alias_copy is a no-op and can be removed.
272+ This is an optimization that removes unnecessary ops.
272273 """
273274
274275 @property
@@ -412,6 +413,9 @@ class Subgraph:
412413 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
413414 exir_ops .edge .cadence .quantize_per_tensor .default ,
414415 exir_ops .edge .cadence .dequantize_per_tensor .default ,
416+ exir_ops .edge .cadence .quantized_relu .per_tensor ,
417+ exir_ops .edge .cadence .requantize .per_tensor ,
418+ exir_ops .edge .cadence .quantized_add .per_tensor ,
415419 # Ops that require special handling.
416420 exir_ops .edge .aten .cat .default ,
417421 exir_ops .edge .aten .mean .dim ,
@@ -804,6 +808,7 @@ class CommonRemovePasses:
804808 RemoveToOpsPass ,
805809 RemoveZeroSizedCatArgsPass ,
806810 RemovePermutesAroundElementwiseOps ,
811+ FuseTransposeOrPermuteOpPairsPass ,
807812 RemoveSqueezeViewBeforeElementwiseOps ,
808813 RemoveCatFromSliceCopyPass ,
809814 RemoveCloneOpsTransformImported ,
0 commit comments