Skip to content

Commit 4bfc8af

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Remove or move permute after mean (#19103)
Summary: If we have a permute -> unary chain -> mean, based on the reduction dims of the mean, we can either fully remove the preceding permute or move the permute after the mean. Case 1: Dims after permute are still in same order with respect to each other, we can fully get rid of the permute and just update the reduction dims of the mean. Case 2: Not case 1. In this case, it's better to move the permute after the mean, since the permute will operate on less data. Reviewed By: abeakkas Differential Revision: D102268214
1 parent d0b7934 commit 4bfc8af

2 files changed

Lines changed: 590 additions & 0 deletions

File tree

backends/cadence/aot/remove_ops.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from executorch.exir.pass_manager import PassManager, PassType
3535
from executorch.exir.passes import dead_code_elimination_pass
3636
from torch.fx.node import Node
37+
from torch.utils import _pytree as pytree
3738

3839

3940
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -387,6 +388,219 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
387388
return False
388389

389390

391+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Remove or sink permute ops that precede mean reductions through unary chains.

    When a permute feeds into a mean (possibly through unary ops like
    dequantize/quantize), two optimizations apply:

    1. If non-reduced dims maintain their relative order and positions, the
       permute is fully removed and the mean's reduction dims are remapped.
    2. Otherwise, the permute is moved after the mean so it operates on
       smaller data.

    Cost model
    ----------
    Let S_in = input size in bytes, S_out = output size in bytes,
    R = S_in / S_out (reduction ratio), C_c = contiguous mean compute
    cost, C_s = strided mean compute cost, and Delta = C_s - C_c.

    Original graph (permute -> mean):
        Cost_orig = 3*S_in + S_out + C_c
        Breakdown: permute reads and writes S_in (2*S_in), mean reads
        S_in and writes S_out.

    Case 1 -- full removal (mean with remapped dims, no permute):
        Cost_remove = S_in + S_out + C_s
        Profitable when: Delta < 2*S_in
        The strided access penalty must be less than twice the full
        input tensor I/O (the eliminated permute cost).

    Case 2 -- reorder (mean with remapped dims -> small permute):
        Cost_reorder = S_in + 3*S_out + C_s
        Profitable when: Delta < 2*(S_in - S_out) = 2*S_in*(R-1)/R
        At R = 4 the threshold is 1.5*S_in; at R = 16 it approaches
        2*S_in, converging to the full-removal bound as S_out becomes
        negligible.

    Full removal always dominates reorder when both are structurally
    possible (Cost_remove < Cost_reorder since S_out < 3*S_out), so
    removal is applied without a cost gate.

    Additionally, if the original permute does not place the reduction
    dims as the trailing (innermost) dimensions, the mean is already
    strided in the original graph. Removing or sinking the permute
    saves the permute I/O (2*S_in) while the mean performance can only
    improve or stay the same. This makes the transformation
    unconditionally profitable without needing the cost model.
    """

    # Elementwise ops that neither reshape nor reorder data, so a permute can
    # be hoisted through them unchanged. Quant/dequant here are per-tensor
    # variants only (per-channel would be axis-sensitive and is not listed).
    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    # Minimum reduction ratio R for Case-3 reorder profitability
    # (approximates the Delta < 2*S_in*(R-1)/R bound; see class docstring).
    _MIN_REDUCTION_RATIO_FOR_REORDER: int = 4

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # The pass is driven from the mean node; the permute is found by
        # walking backward from it.
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk backward from mean through single-user unary ops to find a permute.

        Returns the permute_copy node, or None if the chain is broken by a
        non-unary op, a multi-user node, or a non-Node input. The permute's
        own user count is NOT checked here; the caller does that.
        """
        current = mean_node.args[0]
        if not isinstance(current, Node):
            return None
        while True:
            if current.target == exir_ops.edge.aten.permute_copy.default:
                return current
            if current.target not in self._UNARY_TARGETS:
                return None
            # A fork in the chain would make the rewrite observable to the
            # other user, so require exactly one consumer per chain node.
            if len(current.users) != 1:
                return None
            parent = current.args[0]
            if not isinstance(parent, Node):
                return None
            current = parent

    @staticmethod
    def _reduction_dims_are_trailing(reduction_dims: list[int], ndim: int) -> bool:
        """Check whether all reduction dims are the trailing (innermost) dims."""
        # `% ndim` canonicalizes negative dims before comparing against the
        # trailing index range [ndim - k, ndim).
        canonical = sorted(d % ndim for d in reduction_dims)
        return canonical == list(range(ndim - len(canonical), ndim))

    def _is_reorder_profitable(
        self,
        reduction_dims: list[int],
        new_reduction_dims: list[int],
        ndim: int,
        input_shape: torch.Size,
    ) -> bool:
        """Determine whether sinking the permute past the mean is profitable.

        Three cases, in order of evaluation:

        1. The original permute does not place the reduction dims as trailing
           (innermost) dimensions. The mean is already strided, so removing
           the permute saves 2*S_in with no mean degradation. Always
           profitable.

        2. The new (remapped) reduction dims are trailing in the original
           layout. The mean stays contiguous after the transformation
           (Delta ~ 0). Always profitable.

        3. The mean becomes strided. The reorder is profitable when
           Delta < 2*S_in*(R-1)/R. We approximate this by requiring
           R >= _MIN_REDUCTION_RATIO_FOR_REORDER, ensuring the I/O savings
           from operating on a smaller tensor outweigh the strided access
           penalty.
        """
        if not self._reduction_dims_are_trailing(reduction_dims, ndim):
            return True

        if self._reduction_dims_are_trailing(new_reduction_dims, ndim):
            return True

        # R = product of reduced extents in the (unpermuted) input shape.
        # new_reduction_dims are already in input space, so input_shape[d]
        # indexes the correct extents; the set() drops duplicate dims.
        reduction_ratio = 1
        canonical_reduction = set(new_reduction_dims)
        for d in canonical_reduction:
            reduction_ratio *= input_shape[d]

        return reduction_ratio >= self._MIN_REDUCTION_RATIO_FOR_REORDER

    def maybe_remove_or_replace(self, node: Node) -> bool:
        # `node` is the mean.dim node (see `targets`).
        reduction_dims = get_arg(node, "dim", list[int])

        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None:
            return False

        perm = get_arg(permute_node, "dims", list[int])
        ndim = len(perm)

        # The permute itself must feed only this chain, or removing it would
        # change what its other users see.
        if len(permute_node.users) != 1:
            return False

        permute_input = permute_node.args[0]
        assert isinstance(permute_input, Node)

        # Map each reduced dim from the permuted space back to the input
        # space: mean dim d of permute(x, perm) reduces dim perm[d] of x.
        # Negative d indexes perm from the end, which coincides with d % ndim,
        # so negative reduction dims are handled implicitly.
        new_reduction_dims = [perm[d] for d in reduction_dims]
        keepdim = get_arg(node, "keepdim", bool)

        # Determine if the permute can be fully removed (post-mean permute
        # would be a no-op) vs needing to be sunk after the mean.
        canonical_reduction = set(new_reduction_dims)
        if keepdim:
            # With keepdim the post-mean permute is `perm` itself; it is a
            # no-op when perm fixes every non-reduced dim (reduced dims are
            # size 1 after the mean, so permuting them among themselves is
            # harmless).
            can_remove = all(
                perm[d] == d for d in range(ndim) if d not in canonical_reduction
            )
        else:
            # Without keepdim the reduced dims are squeezed; the permute is
            # redundant when the surviving dims already appear in ascending
            # (i.e. original) order.
            non_reduced_in_perm_order = [
                d for d in perm if d not in canonical_reduction
            ]
            can_remove = non_reduced_in_perm_order == sorted(non_reduced_in_perm_order)

        # Full removal is almost always profitable. Reorder requires a
        # tighter cost bound; verify via the cost model before proceeding.
        if not can_remove:
            input_shape = permute_input.meta["val"].shape
            if not self._is_reorder_profitable(
                reduction_dims, new_reduction_dims, ndim, input_shape
            ):
                return False

        # Rewire: the permute's single user (either the mean itself or the
        # first unary op in the chain) should read from the permute's input.
        # NOTE(review): meta["val"] of intermediate unary-chain nodes is not
        # re-derived here even though their shapes change — presumably a later
        # shape-propagation/DCE step refreshes them; confirm.
        permute_user = next(iter(permute_node.users))
        permute_user.replace_input_with(permute_node, permute_input)
        # NOTE(review): assumes "dim" was passed positionally (args[1]); a
        # kwarg-passed dim would not be rewritten by this tuple surgery.
        node.args = (node.args[0], new_reduction_dims) + node.args[2:]

        # Re-derive the mean's meta since its reduction dims changed.
        fake_args = pytree.tree_map(
            lambda x: x.meta["val"] if isinstance(x, Node) else x,
            node.args,
        )
        node.meta["val"] = node.target(*fake_args)  # pyre-ignore[29]

        if not can_remove:
            # Compute the post-mean permute on the reduced output.
            if keepdim:
                # Rank is preserved, so the original permutation applies
                # directly to the reduced tensor.
                post_perm = list(perm)
            else:
                # Output axis j of the remapped mean corresponds to the j-th
                # smallest surviving input dim; reorder those axes to match
                # the order the original permute produced.
                non_reduced_original = sorted(
                    d for d in range(ndim) if d not in canonical_reduction
                )
                non_reduced_permuted = [d for d in perm if d not in canonical_reduction]
                post_perm = [
                    non_reduced_original.index(d) for d in non_reduced_permuted
                ]

            graph = node.graph
            with graph.inserting_after(node):
                new_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
            new_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                node.meta["val"], post_perm
            )
            # Redirect all original consumers of the mean to the new permute
            # (skipping the permute itself to avoid a self-cycle).
            for user in list(node.users):
                if user is not new_permute:
                    user.replace_input_with(node, new_permute)

        # True signals the pass interface that the graph was modified.
        return True
390604
@register_cadence_pass(CadencePassAttribute(opt_level=2))
391605
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
392606
permutable_ops: set[EdgeOpOverload] = (
@@ -646,6 +860,7 @@ class CommonRemovePasses:
646860
RemoveNopSliceOrViewOpPass,
647861
RemoveToOpsPass,
648862
RemoveZeroSizedCatArgsPass,
863+
RemovePermuteBeforeMeanPass,
649864
RemovePermutesAroundElementwiseOps,
650865
FuseTransposeOrPermuteOpPairsPass,
651866
RemoveSqueezeViewBeforeElementwiseOps,

0 commit comments

Comments
 (0)