Skip to content

Commit bb8197e

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Handle rank-changing views in RemovePermutesAroundElementwiseOps
Summary: Extend RemovePermutesAroundElementwiseOps to cancel permute pairs across rank-changing squeeze/unsqueeze view boundaries. When a permute's sole user is a view_copy that adds or removes a single size-1 dimension, the pass adapts the expected permutation to the new rank and continues traversal. This enables removing permutes that sit on opposite sides of an unsqueeze→elementwise→squeeze chain (e.g. the NHWC↔NTC layout conversion around convolutions in the cascade detector model). Key changes: - Accept extra_permutable_ops constructor parameter for backend-specific ops - Track per-node expected permutations across view boundaries - Run dimension updates before edges_in bypass to preserve original metadata - Handle view_copy, unsqueeze_copy, squeeze_copy rank changes - Treat aten.full.default as a compile-time constant Note: The PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView pass is removed from the Arm pass manager, since it doesn't actually help anymore. Differential Revision: D104775244
1 parent b04cc65 commit bb8197e

6 files changed

Lines changed: 442 additions & 100 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,6 @@
161161
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
162162
FuseCascadedTransposeOrPermuteOps,
163163
)
164-
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
165-
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
166-
)
167-
168164
from executorch.exir import ExportedProgram
169165
from executorch.exir.pass_base import ExportPass
170166
from executorch.exir.pass_manager import PassManager
@@ -538,7 +534,6 @@ def _tosa_pipeline(
538534
RewritePadPass(),
539535
FuseViewCopyTransformPass(),
540536
RemovePermutesAroundElementwiseTosaOps(),
541-
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
542537
FuseCascadedTransposeOrPermuteOps(),
543538
ConvertPermuteSingletonToViewPass(),
544539
RewriteHighRankSingletonPermutePass(),

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111

1212

1313
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
14-
permutable_ops = {
15-
*RemovePermutesAroundElementwiseOps.permutable_ops,
16-
*TableOps.unary_table_ops.keys(),
17-
*TableOps.special_table_ops,
18-
exir_ops.backend.tosa.RESCALE.default,
19-
exir_ops.backend.tosa.TABLE.default,
20-
}
14+
def __init__(self) -> None:
15+
super().__init__(
16+
extra_permutable_ops={
17+
*TableOps.unary_table_ops.keys(),
18+
*TableOps.special_table_ops,
19+
exir_ops.backend.tosa.RESCALE.default,
20+
exir_ops.backend.tosa.TABLE.default,
21+
}
22+
)
2123

2224
def permute_subgraph(self, subgraph):
2325
# Original function will always permute constant nodes which is wrong for table ops

backends/cadence/aot/remove_ops.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -603,16 +603,16 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
603603

604604
@register_cadence_pass(CadencePassAttribute(opt_level=2))
605605
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
606-
permutable_ops: set[EdgeOpOverload] = (
607-
_SharedRemovePermutesAroundElementwiseOps.permutable_ops
608-
| {
609-
exir_ops.edge.cadence.quantize_per_tensor.default,
610-
exir_ops.edge.cadence.dequantize_per_tensor.default,
611-
exir_ops.edge.cadence.quantized_relu.per_tensor,
612-
exir_ops.edge.cadence.requantize.per_tensor,
613-
exir_ops.edge.cadence.quantized_add.per_tensor,
614-
}
615-
)
606+
def __init__(self) -> None:
607+
super().__init__(
608+
extra_permutable_ops={
609+
exir_ops.edge.cadence.quantize_per_tensor.default,
610+
exir_ops.edge.cadence.dequantize_per_tensor.default,
611+
exir_ops.edge.cadence.quantized_relu.per_tensor,
612+
exir_ops.edge.cadence.requantize.per_tensor,
613+
exir_ops.edge.cadence.quantized_add.per_tensor,
614+
}
615+
)
616616

617617

618618
@register_cadence_pass(CadencePassAttribute(opt_level=2))

0 commit comments

Comments
 (0)