Skip to content

Commit fc7560f

Browse files
mcremon-meta authored and facebook-github-bot committed
Handle rank-changing views in RemovePermutesAroundElementwiseOps (#19538)
Summary: Extend RemovePermutesAroundElementwiseOps to cancel permute pairs across rank-changing squeeze/unsqueeze view boundaries. When a permute's sole user is a view_copy that adds or removes a single size-1 dimension, the pass adapts the expected permutation to the new rank and continues traversal. This enables removing permutes that sit on opposite sides of an unsqueeze→elementwise→squeeze chain (e.g. the NHWC↔NTC layout conversion around convolutions in the cascade detector model).

Key changes:
- Accept extra_permutable_ops constructor parameter for backend-specific ops
- Track per-node expected permutations across view boundaries
- Run dimension updates before edges_in bypass to preserve original metadata
- Handle view_copy, unsqueeze_copy, squeeze_copy rank changes
- Treat aten.full.default as a compile-time constant

Note: The PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView pass is removed from the Arm pass manager, since it no longer provides any benefit.

Reviewed By: DrJessop

Differential Revision: D104775244
1 parent 371cb1c commit fc7560f

6 files changed

Lines changed: 578 additions & 99 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -160,10 +160,6 @@
160160
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
161161
FuseCascadedTransposeOrPermuteOps,
162162
)
163-
from executorch.backends.transforms.postpone_permute_below_squeeze_view import (
164-
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
165-
)
166-
167163
from executorch.exir import ExportedProgram
168164
from executorch.exir.pass_base import ExportPass
169165
from executorch.exir.pass_manager import PassManager
@@ -531,7 +527,6 @@ def _tosa_pipeline(
531527
RewritePadPass(),
532528
FuseViewCopyTransformPass(),
533529
RemovePermutesAroundElementwiseTosaOps(),
534-
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(),
535530
FuseCascadedTransposeOrPermuteOps(),
536531
ConvertPermuteSingletonToViewPass(),
537532
RewriteHighRankSingletonPermutePass(),

backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,15 @@
1111

1212

1313
class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps):
14-
permutable_ops = {
15-
*RemovePermutesAroundElementwiseOps.permutable_ops,
16-
*TableOps.unary_table_ops.keys(),
17-
*TableOps.special_table_ops,
18-
exir_ops.backend.tosa.RESCALE.default,
19-
exir_ops.backend.tosa.TABLE.default,
20-
}
14+
def __init__(self) -> None:
15+
super().__init__(
16+
extra_permutable_ops={
17+
*TableOps.unary_table_ops.keys(),
18+
*TableOps.special_table_ops,
19+
exir_ops.backend.tosa.RESCALE.default,
20+
exir_ops.backend.tosa.TABLE.default,
21+
}
22+
)
2123

2224
def permute_subgraph(self, subgraph):
2325
# Original function will always permute constant nodes which is wrong for table ops

backends/cadence/aot/remove_ops.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -603,16 +603,16 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
603603

604604
@register_cadence_pass(CadencePassAttribute(opt_level=2))
605605
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
606-
permutable_ops: set[EdgeOpOverload] = (
607-
_SharedRemovePermutesAroundElementwiseOps.permutable_ops
608-
| {
609-
exir_ops.edge.cadence.quantize_per_tensor.default,
610-
exir_ops.edge.cadence.dequantize_per_tensor.default,
611-
exir_ops.edge.cadence.quantized_relu.per_tensor,
612-
exir_ops.edge.cadence.requantize.per_tensor,
613-
exir_ops.edge.cadence.quantized_add.per_tensor,
614-
}
615-
)
606+
def __init__(self) -> None:
607+
super().__init__(
608+
extra_permutable_ops={
609+
exir_ops.edge.cadence.quantize_per_tensor.default,
610+
exir_ops.edge.cadence.dequantize_per_tensor.default,
611+
exir_ops.edge.cadence.quantized_relu.per_tensor,
612+
exir_ops.edge.cadence.requantize.per_tensor,
613+
exir_ops.edge.cadence.quantized_add.per_tensor,
614+
}
615+
)
616616

617617

618618
@register_cadence_pass(CadencePassAttribute(opt_level=2))

0 commit comments

Comments
 (0)