Skip to content

Commit 76d941e

Browse files
authored
More generic slice propagation before unary ops which works for non-contiguous slices (#19345)
Differential Revision: D103752840 Pull Request resolved: #19345
1 parent 563e237 commit 76d941e

2 files changed

Lines changed: 270 additions & 1 deletion

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from collections import defaultdict
1313
from math import prod
14-
from typing import cast, DefaultDict, List, Tuple
14+
from typing import Callable, cast, DefaultDict, List, Tuple
1515

1616
import torch
1717
import torch.fx
@@ -719,6 +719,109 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
719719
return True
720720

721721

722+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class PropagateSlice(RemoveOrReplacePassInterface):
    """Hoist slice_copy above unary element-wise ops when the cost model
    says the swap reduces total data movement.

    Supported ops (extensible via the dispatch table):
    - quantize_per_tensor: element-wise, slice passes through unchanged
    - dequantize_per_tensor: element-wise, slice passes through unchanged

    Any slice dim and any step size are handled. The pass participates in
    the iterative pass loop, so chains of ops are rewritten by repeated
    application.
    """

    def __init__(self) -> None:
        super().__init__()
        # Each supported target maps to a (should_swap, do_swap) pair so new
        # op kinds can be added without touching the traversal logic.
        ew_targets = (
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.cadence.dequantize_per_tensor.default,
        )
        self._dispatch: dict[
            EdgeOpOverload,
            tuple[
                Callable[[torch.fx.Node, torch.fx.Node], bool],
                Callable[[torch.fx.Node, torch.fx.Node], bool],
            ],
        ] = {
            target: (self._should_swap_elementwise, self._swap_elementwise_slice)
            for target in ew_targets
        }

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # The traversal visits slice_copy nodes; the slice's parent op
        # decides whether a swap applies.
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def _should_swap_elementwise(
        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
    ) -> bool:
        # Swapping pays off exactly when the slice drops elements: the op
        # then runs on the smaller, sliced tensor.
        return prod(slice_node.meta["val"].shape) < prod(op_node.meta["val"].shape)

    def _swap_elementwise_slice(
        self, op_node: torch.fx.Node, slice_node: torch.fx.Node
    ) -> bool:
        # Rewrite `op -> slice` into `slice -> op` reading the op's input.
        source = op_node.args[0]
        assert isinstance(source, torch.fx.Node)
        graph = slice_node.graph

        # slice_copy args after the input tensor (dim/start/end/step) carry
        # over unchanged since the op is element-wise.
        trailing_slice_args = slice_node.args[1:]

        def _val(obj):
            # Meta ("fake tensor") value for nodes; literals pass through.
            return obj.meta["val"] if isinstance(obj, torch.fx.Node) else obj

        with graph.inserting_before(op_node):
            hoisted_slice = graph.call_function(
                exir_ops.edge.aten.slice_copy.Tensor,
                args=(source, *trailing_slice_args),
            )
            hoisted_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
                source.meta["val"], *trailing_slice_args
            )

            updated_args = (hoisted_slice, *op_node.args[1:])
            target = cast(EdgeOpOverload, op_node.target)
            replacement_op = graph.call_function(
                target,
                args=updated_args,
                kwargs=op_node.kwargs,
            )
            replacement_op.meta["val"] = target(
                hoisted_slice.meta["val"],
                *[_val(a) for a in updated_args[1:]],
                **{k: _val(v) for k, v in op_node.kwargs.items()},
            )

        # Redirect consumers to the new op, then drop the old pair. The old
        # op is erased last, after its sole user (the old slice) is gone.
        slice_node.replace_all_uses_with(replacement_op)
        graph.erase_node(slice_node)
        graph.erase_node(op_node)
        return True

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        parent = node.args[0]
        if not isinstance(parent, torch.fx.Node):
            return False
        # Only safe when the slice is the op's sole consumer; otherwise the
        # full-size result is still needed elsewhere.
        if len(parent.users) != 1:
            return False
        if not isinstance(parent.target, EdgeOpOverload):
            return False

        entry = self._dispatch.get(parent.target)
        if entry is None:
            return False

        should_swap, do_swap = entry
        return should_swap(parent, node) and do_swap(parent, node)
722825
# The following class consolidates functions to reoder ops (i.e., either hoist
723826
# or sink some ops in the graph).
724827
class CadenceReorderOpsInGraph:

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MoveSliceBeforePermutePass,
2727
PostponeDequantizeOpBelowUseChainPass,
2828
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
29+
PropagateSlice,
2930
SinkOpsCloserToUsePass,
3031
)
3132
from executorch.backends.test.graph_builder import GraphBuilder
@@ -761,3 +762,168 @@ def test_non_dim0_slice_always_moved(self) -> None:
761762
MoveSliceBeforePermutePass(),
762763
)
763764
self.assertTrue(result.modified)
765+
766+
767+
class TestPropagateSlice(unittest.TestCase):
    """Unit tests for the PropagateSlice reorder pass."""

    def test_swap_quantize_slice(self) -> None:
        # slice_copy after quantize is hoisted to read the placeholder.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 1, 1))
        quantized = gb.call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(quantized, 0, 0, 4, 2),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertTrue(pass_result.modified)

        slices = module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(slices), 1)
        hoisted = slices[0]
        # Slice now consumes the placeholder directly, with the sliced shape.
        self.assertEqual(hoisted.args[0].name, "x")
        self.assertEqual(list(hoisted.meta["val"].shape), [2, 60, 1, 1])

        quants = module.graph.find_nodes(
            op="call_function",
            target=exir_ops.edge.cadence.quantize_per_tensor.default,
        )
        self.assertEqual(len(quants), 1)
        # Quantize now consumes the hoisted slice and inherits its shape.
        self.assertEqual(quants[0].args[0], hoisted)
        self.assertEqual(list(quants[0].meta["val"].shape), [2, 60, 1, 1])

    def test_swap_dequantize_slice(self) -> None:
        # Same hoist applies when the parent op is a dequantize.
        gb = GraphBuilder()
        inp = gb.placeholder(
            "x", torch.randint(0, 255, (4, 60, 4, 4), dtype=torch.uint8)
        )
        dequantized = gb.call_operator(
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(dequantized, 0, 0, 4, 2),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertTrue(pass_result.modified)

        slices = module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(slices), 1)
        self.assertEqual(slices[0].args[0].name, "x")

    def test_step_2_through_quantize(self) -> None:
        # A strided (step=2, non-contiguous) slice still propagates.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 1, 1))
        quantized = gb.call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(quantized, 0, 0, 4, 2),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertTrue(pass_result.modified)

        slices = module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(slices), 1)
        # Step argument survives the hoist; shape reflects the stride.
        self.assertEqual(slices[0].args[4], 2)
        self.assertEqual(list(slices[0].meta["val"].shape), [2, 60, 1, 1])

    def test_non_batch_dim_slice(self) -> None:
        # Slicing an inner dim (dim=1) is also supported.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 4, 4))
        quantized = gb.call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(quantized, 1, 0, 30, 1),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertTrue(pass_result.modified)

        slices = module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(slices), 1)
        self.assertEqual(list(slices[0].meta["val"].shape), [4, 30, 4, 4])

    def test_no_swap_when_multi_user(self) -> None:
        # Quantize result is also a graph output, so the hoist must not fire.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 1, 1))
        quantized = gb.call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(quantized, 0, 0, 4, 2),
        )
        gb.output([out, quantized])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertFalse(pass_result.modified)

    def test_no_swap_noop_slice(self) -> None:
        # A slice covering the full dim drops nothing, so there is no gain.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 1, 1))
        quantized = gb.call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args=(inp, 0.5, 0, 0, 255, torch.uint8),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(quantized, 0, 0, 4, 1),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertFalse(pass_result.modified)

    def test_unsupported_parent_not_swapped(self) -> None:
        # relu is not in the dispatch table, so the graph stays untouched.
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 60, 1, 1))
        activated = gb.call_operator(
            exir_ops.edge.aten.relu.default,
            args=(inp,),
        )
        out = gb.call_operator(
            exir_ops.edge.aten.slice_copy.Tensor,
            args=(activated, 0, 0, 4, 2),
        )
        gb.output([out])
        module = gb.get_graph_module()

        pass_result = PropagateSlice().call(module)

        self.assertFalse(pass_result.modified)
0 commit comments

Comments
 (0)