Skip to content

Commit 180edd3

Browse files
authored
Reorder slice before binary broadcast ops (#19346)
Differential Revision: D103563327 Pull Request resolved: #19346
1 parent d858cd9 commit 180edd3

2 files changed

Lines changed: 188 additions & 18 deletions

File tree

backends/cadence/aot/reorder_ops.py

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -721,15 +721,16 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
721721

722722
@register_cadence_pass(CadencePassAttribute(opt_level=1))
723723
class PropagateSlice(RemoveOrReplacePassInterface):
724-
"""Propagate slice_copy before unary element-wise ops when the cost
725-
model indicates it reduces total data movement.
724+
"""Propagate slice_copy before element-wise ops when the cost model
725+
indicates it reduces total data movement.
726726
727727
Supported ops (extensible via dispatch table):
728-
- quantize_per_tensor: element-wise, slice passes through unchanged
729-
- dequantize_per_tensor: element-wise, slice passes through unchanged
728+
- quantize_per_tensor: unary element-wise
729+
- dequantize_per_tensor: unary element-wise
730+
- add.Tensor: binary with broadcast — slices non-broadcasting inputs
731+
- mul.Tensor: binary with broadcast — slices non-broadcasting inputs
730732
731-
Handles any slice dim and any step size. Runs in the iterative pass
732-
loop — chains are handled by repeated application.
733+
Handles any slice dim and any step size.
733734
"""
734735

735736
def __init__(self) -> None:
@@ -740,16 +741,28 @@ def __init__(self) -> None:
740741
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
741742
exir_ops.edge.cadence.dequantize_per_tensor.default,
742743
]
744+
binary_targets = [
745+
exir_ops.edge.aten.add.Tensor,
746+
exir_ops.edge.aten.mul.Tensor,
747+
]
743748
self._dispatch: dict[
744749
EdgeOpOverload,
745750
tuple[
746751
Callable[[torch.fx.Node, torch.fx.Node], bool],
747752
Callable[[torch.fx.Node, torch.fx.Node], bool],
748753
],
749-
] = {
750-
t: (self._should_swap_elementwise, self._swap_elementwise_slice)
751-
for t in elementwise_targets
752-
}
754+
] = {}
755+
for t in elementwise_targets:
756+
self._dispatch[t] = (
757+
self._should_swap_elementwise,
758+
self._swap_elementwise_slice,
759+
)
760+
761+
for t in binary_targets:
762+
self._dispatch[t] = (
763+
self._should_swap_binary_elementwise,
764+
self._swap_binary_elementwise_slice,
765+
)
753766

754767
@property
755768
def targets(self) -> list[EdgeOpOverload]:
@@ -765,19 +778,21 @@ def _should_swap_elementwise(
765778
def _swap_elementwise_slice(
766779
self, op_node: torch.fx.Node, slice_node: torch.fx.Node
767780
) -> bool:
768-
op_input = op_node.args[0]
769-
assert isinstance(op_input, torch.fx.Node)
781+
op_input = get_arg(op_node, "input", torch.fx.Node)
770782
graph = slice_node.graph
771783

772-
slice_args = slice_node.args[1:]
784+
slice_dim = get_arg(slice_node, "dim", int)
785+
slice_start = get_arg(slice_node, "start")
786+
slice_end = get_arg(slice_node, "end")
787+
slice_step = get_arg(slice_node, "step", int)
773788

774789
with graph.inserting_before(op_node):
775790
new_slice = graph.call_function(
776791
exir_ops.edge.aten.slice_copy.Tensor,
777-
args=(op_input, *slice_args),
792+
args=(op_input, slice_dim, slice_start, slice_end, slice_step),
778793
)
779794
new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
780-
op_input.meta["val"], *slice_args
795+
op_input.meta["val"], slice_dim, slice_start, slice_end, slice_step
781796
)
782797

783798
new_args = list(op_node.args)
@@ -805,10 +820,68 @@ def _swap_elementwise_slice(
805820
graph.erase_node(op_node)
806821
return True
807822

808-
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
809-
parent = node.args[0]
810-
if not isinstance(parent, torch.fx.Node):
823+
def _should_swap_binary_elementwise(
824+
self, op_node: torch.fx.Node, slice_node: torch.fx.Node
825+
) -> bool:
826+
lhs, rhs = op_node.args[0], op_node.args[1]
827+
assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
828+
if lhs.meta["val"].shape == rhs.meta["val"].shape:
811829
return False
830+
full_size = prod(op_node.meta["val"].shape)
831+
sliced_size = prod(slice_node.meta["val"].shape)
832+
return sliced_size < full_size
833+
834+
def _swap_binary_elementwise_slice(
835+
self, op_node: torch.fx.Node, slice_node: torch.fx.Node
836+
) -> bool:
837+
lhs, rhs = op_node.args[0], op_node.args[1]
838+
assert isinstance(lhs, torch.fx.Node) and isinstance(rhs, torch.fx.Node)
839+
graph = slice_node.graph
840+
841+
slice_dim = get_arg(slice_node, "dim", int)
842+
slice_start = get_arg(slice_node, "start")
843+
slice_end = get_arg(slice_node, "end")
844+
slice_step = get_arg(slice_node, "step", int)
845+
846+
output_shape = op_node.meta["val"].shape
847+
848+
new_args = list(op_node.args)
849+
with graph.inserting_before(op_node):
850+
for i, inp in enumerate([lhs, rhs]):
851+
if inp.meta["val"].shape[slice_dim] == output_shape[slice_dim]:
852+
new_slice = graph.call_function(
853+
exir_ops.edge.aten.slice_copy.Tensor,
854+
args=(inp, slice_dim, slice_start, slice_end, slice_step),
855+
)
856+
new_slice.meta["val"] = exir_ops.edge.aten.slice_copy.Tensor(
857+
inp.meta["val"], slice_dim, slice_start, slice_end, slice_step
858+
)
859+
new_args[i] = new_slice
860+
861+
target = cast(EdgeOpOverload, op_node.target)
862+
new_op = graph.call_function(
863+
target,
864+
args=tuple(new_args),
865+
kwargs=op_node.kwargs,
866+
)
867+
new_op.meta["val"] = target(
868+
*[
869+
a.meta["val"] if isinstance(a, torch.fx.Node) else a
870+
for a in new_args
871+
],
872+
**{
873+
k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
874+
for k, v in op_node.kwargs.items()
875+
},
876+
)
877+
878+
slice_node.replace_all_uses_with(new_op)
879+
graph.erase_node(slice_node)
880+
graph.erase_node(op_node)
881+
return True
882+
883+
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
884+
parent = get_arg(node, "input", torch.fx.Node)
812885
if len(parent.users) != 1:
813886
return False
814887
if not isinstance(parent.target, EdgeOpOverload):

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,100 @@ def test_unsupported_parent_not_swapped(self) -> None:
927927
result = PropagateSlice().call(gm)
928928

929929
self.assertFalse(result.modified)
930+
931+
def test_swap_broadcast_mul_slice_on_broadcast_dim(self) -> None:
932+
"""[1,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=0, step=2)
933+
Only the [4,1,1,1] input should be sliced."""
934+
builder = GraphBuilder()
935+
a = builder.placeholder("a", torch.randn(1, 60, 1, 1))
936+
b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
937+
mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
938+
sliced = builder.call_operator(
939+
exir_ops.edge.aten.slice_copy.Tensor,
940+
args=(mul, 0, 0, 4, 2),
941+
)
942+
builder.output([sliced])
943+
gm = builder.get_graph_module()
944+
945+
result = PropagateSlice().call(gm)
946+
947+
self.assertTrue(result.modified)
948+
949+
slice_nodes = gm.graph.find_nodes(
950+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
951+
)
952+
self.assertEqual(len(slice_nodes), 1)
953+
self.assertEqual(slice_nodes[0].args[0].name, "b")
954+
self.assertEqual(list(slice_nodes[0].meta["val"].shape), [2, 1, 1, 1])
955+
956+
mul_nodes = gm.graph.find_nodes(
957+
op="call_function", target=exir_ops.edge.aten.mul.Tensor
958+
)
959+
self.assertEqual(len(mul_nodes), 1)
960+
self.assertEqual(list(mul_nodes[0].meta["val"].shape), [2, 60, 1, 1])
961+
962+
def test_swap_broadcast_add_lhs_broadcasts(self) -> None:
963+
"""[1,60,4,4] + [4,60,4,4] → [4,60,4,4] → slice(dim=0, step=2)
964+
Only the [4,60,4,4] (rhs) should be sliced."""
965+
builder = GraphBuilder()
966+
a = builder.placeholder("a", torch.randn(1, 60, 4, 4))
967+
b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
968+
add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
969+
sliced = builder.call_operator(
970+
exir_ops.edge.aten.slice_copy.Tensor,
971+
args=(add, 0, 0, 4, 2),
972+
)
973+
builder.output([sliced])
974+
gm = builder.get_graph_module()
975+
976+
result = PropagateSlice().call(gm)
977+
978+
self.assertTrue(result.modified)
979+
980+
slice_nodes = gm.graph.find_nodes(
981+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
982+
)
983+
self.assertEqual(len(slice_nodes), 1)
984+
self.assertEqual(slice_nodes[0].args[0].name, "b")
985+
986+
def test_swap_broadcast_mul_slice_on_non_broadcast_dim(self) -> None:
987+
"""[4,60,1,1] * [4,1,1,1] → [4,60,1,1] → slice(dim=1, start=0, end=30)
988+
Only the [4,60,1,1] (lhs) should be sliced since rhs has dim1=1."""
989+
builder = GraphBuilder()
990+
a = builder.placeholder("a", torch.randn(4, 60, 1, 1))
991+
b = builder.placeholder("b", torch.randn(4, 1, 1, 1))
992+
mul = builder.call_operator(exir_ops.edge.aten.mul.Tensor, args=(a, b))
993+
sliced = builder.call_operator(
994+
exir_ops.edge.aten.slice_copy.Tensor,
995+
args=(mul, 1, 0, 30, 1),
996+
)
997+
builder.output([sliced])
998+
gm = builder.get_graph_module()
999+
1000+
result = PropagateSlice().call(gm)
1001+
1002+
self.assertTrue(result.modified)
1003+
1004+
slice_nodes = gm.graph.find_nodes(
1005+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
1006+
)
1007+
self.assertEqual(len(slice_nodes), 1)
1008+
self.assertEqual(slice_nodes[0].args[0].name, "a")
1009+
self.assertEqual(list(slice_nodes[0].meta["val"].shape), [4, 30, 1, 1])
1010+
1011+
def test_no_swap_binary_same_shape(self) -> None:
1012+
"""Same-shape binary ops are not swapped (no broadcast)."""
1013+
builder = GraphBuilder()
1014+
a = builder.placeholder("a", torch.randn(4, 60, 4, 4))
1015+
b = builder.placeholder("b", torch.randn(4, 60, 4, 4))
1016+
add = builder.call_operator(exir_ops.edge.aten.add.Tensor, args=(a, b))
1017+
sliced = builder.call_operator(
1018+
exir_ops.edge.aten.slice_copy.Tensor,
1019+
args=(add, 0, 0, 4, 2),
1020+
)
1021+
builder.output([sliced])
1022+
gm = builder.get_graph_module()
1023+
1024+
result = PropagateSlice().call(gm)
1025+
1026+
self.assertFalse(result.modified)

0 commit comments

Comments (0)