Skip to content

Commit 831b7d4

Browse files
authored
Remove or move permute after mean (#19103)
Differential Revision: D102268214 Pull Request resolved: #19103
1 parent 443d96a commit 831b7d4

2 files changed

Lines changed: 590 additions & 0 deletions

File tree

backends/cadence/aot/remove_ops.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from executorch.exir.pass_manager import PassManager, PassType
3535
from executorch.exir.passes import dead_code_elimination_pass
3636
from torch.fx.node import Node
37+
from torch.utils import _pytree as pytree
3738

3839

3940
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -387,6 +388,219 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
387388
return False
388389

389390

391+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Remove or sink permute ops that precede mean reductions through unary chains.

    When a permute feeds into a mean (possibly through unary ops like
    dequantize/quantize), two optimizations apply:

    1. If non-reduced dims maintain their relative order and positions, the
       permute is fully removed and the mean's reduction dims are remapped.
    2. Otherwise, the permute is moved after the mean so it operates on
       smaller data.

    Cost model
    ----------
    Let S_in = input size in bytes, S_out = output size in bytes,
    R = S_in / S_out (reduction ratio), C_c = contiguous mean compute
    cost, C_s = strided mean compute cost, and Delta = C_s - C_c.

    Original graph (permute -> mean):
        Cost_orig = 3*S_in + S_out + C_c
        Breakdown: permute reads and writes S_in (2*S_in), mean reads
        S_in and writes S_out.

    Case 1 -- full removal (mean with remapped dims, no permute):
        Cost_remove = S_in + S_out + C_s
        Profitable when: Delta < 2*S_in
        The strided access penalty must be less than twice the full
        input tensor I/O (the eliminated permute cost).

    Case 2 -- reorder (mean with remapped dims -> small permute):
        Cost_reorder = S_in + 3*S_out + C_s
        Profitable when: Delta < 2*(S_in - S_out) = 2*S_in*(R-1)/R
        At R = 4 the threshold is 1.5*S_in; at R = 16 it approaches
        2*S_in, converging to the full-removal bound as S_out becomes
        negligible.

    Full removal always dominates reorder when both are structurally
    possible (Cost_remove < Cost_reorder since S_out < 3*S_out), so
    removal is applied without a cost gate.

    Additionally, if the original permute does not place the reduction
    dims as the trailing (innermost) dimensions, the mean is already
    strided in the original graph. Removing or sinking the permute
    saves the permute I/O (2*S_in) while the mean performance can only
    improve or stay the same. This makes the transformation
    unconditionally profitable without needing the cost model.
    """

    # Elementwise single-input ops that a permute can be traced through:
    # each maps input element i to output element i, so the permutation
    # commutes with them unchanged.
    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    # Approximation threshold for Case 3 of _is_reorder_profitable: require
    # R >= 4 so the saved tensor I/O outweighs the strided-access penalty
    # (see the class docstring's cost model).
    _MIN_REDUCTION_RATIO_FOR_REORDER: int = 4

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # This pass anchors on mean.dim nodes and walks backward from them.
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk backward from mean through single-user unary ops to find a permute.

        Returns the permute_copy node feeding (directly or through a chain of
        _UNARY_TARGETS ops) into ``mean_node``, or None if the chain is broken
        by a non-unary op, a multi-user intermediate, or a non-Node input.
        Note: the permute's own user count is NOT checked here; the caller
        verifies it separately.
        """
        current = mean_node.args[0]
        if not isinstance(current, Node):
            return None
        while True:
            if current.target == exir_ops.edge.aten.permute_copy.default:
                return current
            if current.target not in self._UNARY_TARGETS:
                return None
            # An intermediate with other users cannot be rewired safely.
            if len(current.users) != 1:
                return None
            parent = current.args[0]
            if not isinstance(parent, Node):
                return None
            current = parent

    @staticmethod
    def _reduction_dims_are_trailing(reduction_dims: list[int], ndim: int) -> bool:
        """Check whether all reduction dims are the trailing (innermost) dims."""
        # ``% ndim`` canonicalizes negative dims before comparing against the
        # suffix [ndim - k, ndim).
        canonical = sorted(d % ndim for d in reduction_dims)
        return canonical == list(range(ndim - len(canonical), ndim))

    def _is_reorder_profitable(
        self,
        reduction_dims: list[int],
        new_reduction_dims: list[int],
        ndim: int,
        input_shape: torch.Size,
    ) -> bool:
        """Determine whether sinking the permute past the mean is profitable.

        Three cases, in order of evaluation:

        1. The original permute does not place the reduction dims as trailing
           (innermost) dimensions. The mean is already strided, so removing
           the permute saves 2*S_in with no mean degradation. Always
           profitable.

        2. The new (remapped) reduction dims are trailing in the original
           layout. The mean stays contiguous after the transformation
           (Delta ~ 0). Always profitable.

        3. The mean becomes strided. The reorder is profitable when
           Delta < 2*S_in*(R-1)/R. We approximate this by requiring
           R >= _MIN_REDUCTION_RATIO_FOR_REORDER, ensuring the I/O savings
           from operating on a smaller tensor outweigh the strided access
           penalty.
        """
        if not self._reduction_dims_are_trailing(reduction_dims, ndim):
            return True

        if self._reduction_dims_are_trailing(new_reduction_dims, ndim):
            return True

        # R = product of reduced dim sizes (output is smaller by this factor).
        reduction_ratio = 1
        canonical_reduction = set(new_reduction_dims)
        for d in canonical_reduction:
            reduction_ratio *= input_shape[d]

        return reduction_ratio >= self._MIN_REDUCTION_RATIO_FOR_REORDER

    def maybe_remove_or_replace(self, node: Node) -> bool:
        """Try to remove (or sink past the mean) a permute feeding ``node``.

        Returns True iff the graph was mutated.
        """
        reduction_dims = get_arg(node, "dim", list[int])

        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None:
            return False

        perm = get_arg(permute_node, "dims", list[int])
        ndim = len(perm)

        # The permute must feed only this chain; otherwise removing it would
        # break its other consumers.
        if len(permute_node.users) != 1:
            return False

        permute_input = permute_node.args[0]
        assert isinstance(permute_input, Node)

        # permute output dim d reads input dim perm[d], so reducing output
        # dim d is reducing input dim perm[d]. (Negative d indexes perm from
        # the end, which matches canonicalization by ndim.)
        new_reduction_dims = [perm[d] for d in reduction_dims]
        keepdim = get_arg(node, "keepdim", bool)

        # Determine if the permute can be fully removed (post-mean permute
        # would be a no-op) vs needing to be sunk after the mean.
        canonical_reduction = set(new_reduction_dims)
        if keepdim:
            # With keepdim the rank is preserved, so every non-reduced dim
            # must be a fixed point of the permutation.
            can_remove = all(
                perm[d] == d for d in range(ndim) if d not in canonical_reduction
            )
        else:
            # Without keepdim the reduced dims vanish; removal only needs the
            # surviving dims to keep their relative order.
            non_reduced_in_perm_order = [
                d for d in perm if d not in canonical_reduction
            ]
            can_remove = non_reduced_in_perm_order == sorted(non_reduced_in_perm_order)

        # Full removal is almost always profitable. Reorder requires a
        # tighter cost bound; verify via the cost model before proceeding.
        if not can_remove:
            input_shape = permute_input.meta["val"].shape
            if not self._is_reorder_profitable(
                reduction_dims, new_reduction_dims, ndim, input_shape
            ):
                return False

        # Rewire: the permute's single user (either the mean itself or the
        # first unary op in the chain) should read from the permute's input.
        permute_user = next(iter(permute_node.users))
        permute_user.replace_input_with(permute_node, permute_input)
        # NOTE(review): args are set positionally here; assumes get_arg above
        # normalized "dim"/"keepdim" regardless of how callers passed them --
        # verify against the get_arg helper.
        node.args = (node.args[0], new_reduction_dims) + node.args[2:]

        # Re-derive the mean's meta since its reduction dims changed.
        fake_args = pytree.tree_map(
            lambda x: x.meta["val"] if isinstance(x, Node) else x,
            node.args,
        )
        node.meta["val"] = node.target(*fake_args)  # pyre-ignore[29]

        if not can_remove:
            # Compute the post-mean permute on the reduced output.
            if keepdim:
                # Rank unchanged: the original permutation applies as-is.
                post_perm = list(perm)
            else:
                # Re-index the surviving dims into the compacted (reduced)
                # coordinate space, preserving the order the permute imposed.
                non_reduced_original = sorted(
                    d for d in range(ndim) if d not in canonical_reduction
                )
                non_reduced_permuted = [d for d in perm if d not in canonical_reduction]
                post_perm = [
                    non_reduced_original.index(d) for d in non_reduced_permuted
                ]

            graph = node.graph
            with graph.inserting_after(node):
                new_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
                new_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                    node.meta["val"], post_perm
                )
            # Snapshot users before mutating: every consumer of the mean
            # (except the permute we just inserted) now reads the permuted
            # result instead.
            for user in list(node.users):
                if user is not new_permute:
                    user.replace_input_with(node, new_permute)

        return True
390604
@register_cadence_pass(CadencePassAttribute(opt_level=2))
391605
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
392606
permutable_ops: set[EdgeOpOverload] = (
@@ -646,6 +860,7 @@ class CommonRemovePasses:
646860
RemoveNopSliceOrViewOpPass,
647861
RemoveToOpsPass,
648862
RemoveZeroSizedCatArgsPass,
863+
RemovePermuteBeforeMeanPass,
649864
RemovePermutesAroundElementwiseOps,
650865
FuseTransposeOrPermuteOpPairsPass,
651866
RemoveSqueezeViewBeforeElementwiseOps,

0 commit comments

Comments
 (0)