Skip to content

Commit 443d96a

Browse files
authored
Reorder slice before permute (#19129)
Differential Revision: D102426699 Pull Request resolved: #19129
1 parent 904c667 commit 443d96a

2 files changed

Lines changed: 207 additions & 1 deletion

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111

1212
from collections import defaultdict
1313
from math import prod
14-
from typing import DefaultDict, List, Tuple
14+
from typing import cast, DefaultDict, List, Tuple
1515

1616
import torch
1717
import torch.fx
1818
from executorch.backends.cadence.aot.compiler_utils import get_placeholders, get_shape
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
21+
get_arg,
2122
get_overload_packet,
2223
register_cadence_pass,
2324
RemoveOrReplacePassInterface,
@@ -641,6 +642,83 @@ class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(
641642
pass
642643

643644

645+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class MoveSliceBeforePermutePass(RemoveOrReplacePassInterface):
    """Move slice_copy ops before permute_copy to reduce permute data volume.

    Rewrites permute(input, perm) -> slice(dim=D) into
    slice(input, dim=perm[D]) -> permute(sliced, perm), so the permute
    operates on a smaller tensor.

    Scans slice nodes and matches upstream permutes. This also handles
    chained cases (permute -> slice -> slice) in one pass: each slice
    independently checks its input for a permute.

    Cost model: dim-0 slices are nop-eligible (zero-copy pointer offset
    after MakeSliceAndCatDimOutermostPass). Moving such a slice loses the
    nop, so we only move it when the permute savings outweigh the nop loss,
    i.e. when the slice removes more than half the data (full > 2 * sliced).
    Non-dim-0 slices have no nop opportunity, so any permute savings is
    pure win.

    A negative slice dim is normalized before the cost model is applied, so
    ``dim == -rank`` is treated exactly like ``dim == 0``.
    """

    # A dim-0 slice is only moved when the permuted tensor has strictly more
    # than this many times the elements of the slice result.
    STRIDED_SLICE_COST_FACTOR: int = 2

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass visits: only ``slice_copy.Tensor`` nodes."""
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Hoist the slice ``node`` above the permute that feeds it.

        Args:
            node: A ``slice_copy.Tensor`` node.

        Returns:
            True if a new slice -> permute pair was inserted and every use of
            ``node`` was redirected to the new permute; False if the graph was
            left untouched. Cleanup of the replaced nodes is not done here --
            presumably the pass interface handles it; confirm in
            RemoveOrReplacePassInterface.
        """
        permute_node = get_arg(node, "input", torch.fx.Node)
        if permute_node.target != exir_ops.edge.aten.permute_copy.default:
            return False

        # `node` is a user of the permute, so exactly one user means the slice
        # is the only consumer. With other consumers the full-size permute
        # would have to stay, and the rewrite would only add work.
        if len(permute_node.users) != 1:
            return False

        perm = cast(list[int], permute_node.args[1])
        permute_input = permute_node.args[0]
        assert isinstance(permute_input, torch.fx.Node)

        slice_dim = get_arg(node, "dim", int)
        # Normalize a negative dim so the dim-0 check below also recognizes
        # the outermost dim when it is spelled as -rank. The index mapping
        # perm[slice_dim] is the same for either spelling.
        if slice_dim < 0:
            slice_dim += len(perm)

        full_size = prod(permute_node.meta["val"].shape)
        sliced_size = prod(node.meta["val"].shape)
        # NOTE(review): when perm[0] == 0 the hoisted slice still targets dim 0
        # of the permute input and might stay nop-eligible; the cost model
        # does not special-case that.
        if slice_dim == 0 and full_size <= self.STRIDED_SLICE_COST_FACTOR * sliced_size:
            return False

        # Output dim `slice_dim` of the permute is read from input dim
        # perm[slice_dim], so that is the dim to slice on the permute input.
        new_dim = perm[slice_dim]
        graph = node.graph

        with graph.inserting_before(permute_node):
            new_slice_args = (
                permute_input,
                new_dim,
                get_arg(node, "start"),
                get_arg(node, "end"),
                get_arg(node, "step", int),
            )
            new_slice = graph.create_node(
                "call_function",
                exir_ops.edge.aten.slice_copy.Tensor,
                args=new_slice_args,
            )
            # Recompute the "val" metadata by running the op on the input's
            # recorded value, so later shape queries see the sliced shape.
            new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
                permute_input.meta["val"], *new_slice_args[1:]
            )
            new_permute = graph.create_node(
                "call_function",
                exir_ops.edge.aten.permute_copy.default,
                args=(new_slice, perm),
            )
            new_permute.meta["val"] = exir_ops.edge.aten.permute_copy.default(
                new_slice.meta["val"], perm
            )

        node.replace_all_uses_with(new_permute)
        return True
720+
721+
644722
# The following class consolidates functions to reorder ops (i.e., either hoist
645723
# or sink some ops in the graph).
646724
class CadenceReorderOpsInGraph:

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from executorch.backends.cadence.aot.reorder_ops import (
2424
AdvanceQuantizeOpAboveDefChainPass,
2525
AdvanceQuantizeOpAboveDefInBranchPass,
26+
MoveSliceBeforePermutePass,
2627
PostponeDequantizeOpBelowUseChainPass,
2728
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
2829
SinkOpsCloserToUsePass,
@@ -633,3 +634,130 @@ def test_permute_view_chains_neg(self) -> None:
633634
self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
634635
self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy)
635636
self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)
637+
638+
639+
class TestMoveSliceBeforePermutePass(unittest.TestCase):
    """Graph-level checks for MoveSliceBeforePermutePass.

    Every graph feeds a 4-D placeholder through permute_copy with dims
    [0, 2, 3, 1]; the tests vary what consumes the permuted value.
    """

    @staticmethod
    def _permute_slice_graph(
        shape: tuple[int, ...], slice_args: tuple[int, int, int, int]
    ) -> torch.fx.GraphModule:
        """Build placeholder -> permute_copy -> slice_copy -> output.

        ``slice_args`` holds the slice operands that follow the input tensor
        (dim, start, end, step).
        """
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(*shape))
        transposed = gb.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(inp, [0, 2, 3, 1])
        )
        narrowed = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(transposed, *slice_args),
        )
        gb.output([narrowed])
        return gb.get_graph_module()

    def _assert_slice_then_permute(self, gm: torch.fx.GraphModule) -> None:
        """Check that the compute nodes are exactly [slice_copy, permute_copy]."""
        ops = get_compute_nodes_in_gm(gm)
        self.assertEqual(len(ops), 2)
        self.assertEqual(ops[0], exir_ops.edge.aten.slice_copy)
        self.assertEqual(ops[1], exir_ops.edge.aten.permute_copy)

    def test_basic_move(self) -> None:
        """A slice fed by a permute is hoisted above that permute."""
        gm = self._permute_slice_graph((2, 3, 4, 5), (1, 0, 2, 1))

        outcome = transform_and_check_numerics(
            gm,
            (torch.randn(2, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self.assertTrue(outcome.modified)
        self._assert_slice_then_permute(outcome.graph_module)

    def test_multi_user_permute_no_change(self) -> None:
        """A permute consumed by two slices is left alone (single-user only)."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
        transposed = gb.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(inp, [0, 2, 3, 1])
        )
        first = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(transposed, 1, 0, 2, 1),
        )
        second = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(transposed, 2, 1, 3, 1),
        )
        gb.output([first, second])
        gm = gb.get_graph_module()

        outcome = cast(PassResult, MoveSliceBeforePermutePass()(gm))
        self.assertFalse(outcome.modified)

    def test_no_slice_users_no_change(self) -> None:
        """A permute whose only consumer is not a slice is left alone."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
        transposed = gb.call_operator(
            op=exir_ops.edge.aten.permute_copy.default, args=(inp, [0, 2, 3, 1])
        )
        negated = gb.call_operator(
            op=exir_ops.edge.aten.neg.default, args=(transposed,)
        )
        gb.output([negated])
        gm = gb.get_graph_module()

        outcome = cast(PassResult, MoveSliceBeforePermutePass()(gm))
        self.assertFalse(outcome.modified)

    def test_dim0_slice_large_reduction_moved(self) -> None:
        """A dim-0 slice keeping 2 of 10 rows (>50% removed) is moved."""
        gm = self._permute_slice_graph((10, 3, 4, 5), (0, 0, 2, 1))

        outcome = transform_and_check_numerics(
            gm,
            (torch.randn(10, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self.assertTrue(outcome.modified)
        self._assert_slice_then_permute(outcome.graph_module)

    def test_dim0_slice_small_reduction_not_moved(self) -> None:
        """A dim-0 slice keeping 8 of 10 rows (<50% removed) stays in place."""
        gm = self._permute_slice_graph((10, 3, 4, 5), (0, 0, 8, 1))

        outcome = cast(PassResult, MoveSliceBeforePermutePass()(gm))
        self.assertFalse(outcome.modified)

    def test_non_dim0_slice_always_moved(self) -> None:
        """A slice on a non-zero dim is moved regardless of the reduction."""
        gm = self._permute_slice_graph((10, 3, 4, 5), (2, 0, 3, 1))

        outcome = transform_and_check_numerics(
            gm,
            (torch.randn(10, 3, 4, 5),),
            MoveSliceBeforePermutePass(),
        )
        self.assertTrue(outcome.modified)

0 commit comments

Comments
 (0)