@@ -721,15 +721,16 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
721721
722722@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
723723class PropagateSlice (RemoveOrReplacePassInterface ):
724- """Propagate slice_copy before unary element-wise ops when the cost
725- model indicates it reduces total data movement.
724+ """Propagate slice_copy before element-wise ops when the cost model
725+ indicates it reduces total data movement.
726726
727727 Supported ops (extensible via dispatch table):
728- - quantize_per_tensor: element-wise, slice passes through unchanged
729- - dequantize_per_tensor: element-wise, slice passes through unchanged
728+ - quantize_per_tensor: unary element-wise
729+ - dequantize_per_tensor: unary element-wise
730+ - add.Tensor: binary with broadcast — slices non-broadcasting inputs
731+ - mul.Tensor: binary with broadcast — slices non-broadcasting inputs
730732
731- Handles any slice dim and any step size. Runs in the iterative pass
732- loop — chains are handled by repeated application.
733+ Handles any slice dim and any step size.
733734 """
734735
735736 def __init__ (self ) -> None :
@@ -740,16 +741,28 @@ def __init__(self) -> None:
740741 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
741742 exir_ops .edge .cadence .dequantize_per_tensor .default ,
742743 ]
744+ binary_targets = [
745+ exir_ops .edge .aten .add .Tensor ,
746+ exir_ops .edge .aten .mul .Tensor ,
747+ ]
743748 self ._dispatch : dict [
744749 EdgeOpOverload ,
745750 tuple [
746751 Callable [[torch .fx .Node , torch .fx .Node ], bool ],
747752 Callable [[torch .fx .Node , torch .fx .Node ], bool ],
748753 ],
749- ] = {
750- t : (self ._should_swap_elementwise , self ._swap_elementwise_slice )
751- for t in elementwise_targets
752- }
754+ ] = {}
755+ for t in elementwise_targets :
756+ self ._dispatch [t ] = (
757+ self ._should_swap_elementwise ,
758+ self ._swap_elementwise_slice ,
759+ )
760+
761+ for t in binary_targets :
762+ self ._dispatch [t ] = (
763+ self ._should_swap_binary_elementwise ,
764+ self ._swap_binary_elementwise_slice ,
765+ )
753766
754767 @property
755768 def targets (self ) -> list [EdgeOpOverload ]:
@@ -765,19 +778,21 @@ def _should_swap_elementwise(
765778 def _swap_elementwise_slice (
766779 self , op_node : torch .fx .Node , slice_node : torch .fx .Node
767780 ) -> bool :
768- op_input = op_node .args [0 ]
769- assert isinstance (op_input , torch .fx .Node )
781+ op_input = get_arg (op_node , "input" , torch .fx .Node )
770782 graph = slice_node .graph
771783
772- slice_args = slice_node .args [1 :]
784+ slice_dim = get_arg (slice_node , "dim" , int )
785+ slice_start = get_arg (slice_node , "start" )
786+ slice_end = get_arg (slice_node , "end" )
787+ slice_step = get_arg (slice_node , "step" , int )
773788
774789 with graph .inserting_before (op_node ):
775790 new_slice = graph .call_function (
776791 exir_ops .edge .aten .slice_copy .Tensor ,
777- args = (op_input , * slice_args ),
792+ args = (op_input , slice_dim , slice_start , slice_end , slice_step ),
778793 )
779794 new_slice .meta ["val" ] = exir_ops .edge .aten .slice_copy .Tensor (
780- op_input .meta ["val" ], * slice_args
795+ op_input .meta ["val" ], slice_dim , slice_start , slice_end , slice_step
781796 )
782797
783798 new_args = list (op_node .args )
@@ -805,10 +820,68 @@ def _swap_elementwise_slice(
805820 graph .erase_node (op_node )
806821 return True
807822
808- def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
809- parent = node .args [0 ]
810- if not isinstance (parent , torch .fx .Node ):
823+ def _should_swap_binary_elementwise (
824+ self , op_node : torch .fx .Node , slice_node : torch .fx .Node
825+ ) -> bool :
826+ lhs , rhs = op_node .args [0 ], op_node .args [1 ]
827+ assert isinstance (lhs , torch .fx .Node ) and isinstance (rhs , torch .fx .Node )
828+ if lhs .meta ["val" ].shape == rhs .meta ["val" ].shape :
811829 return False
830+ full_size = prod (op_node .meta ["val" ].shape )
831+ sliced_size = prod (slice_node .meta ["val" ].shape )
832+ return sliced_size < full_size
833+
834+ def _swap_binary_elementwise_slice (
835+ self , op_node : torch .fx .Node , slice_node : torch .fx .Node
836+ ) -> bool :
837+ lhs , rhs = op_node .args [0 ], op_node .args [1 ]
838+ assert isinstance (lhs , torch .fx .Node ) and isinstance (rhs , torch .fx .Node )
839+ graph = slice_node .graph
840+
841+ slice_dim = get_arg (slice_node , "dim" , int )
842+ slice_start = get_arg (slice_node , "start" )
843+ slice_end = get_arg (slice_node , "end" )
844+ slice_step = get_arg (slice_node , "step" , int )
845+
846+ output_shape = op_node .meta ["val" ].shape
847+
848+ new_args = list (op_node .args )
849+ with graph .inserting_before (op_node ):
850+ for i , inp in enumerate ([lhs , rhs ]):
851+ if inp .meta ["val" ].shape [slice_dim ] == output_shape [slice_dim ]:
852+ new_slice = graph .call_function (
853+ exir_ops .edge .aten .slice_copy .Tensor ,
854+ args = (inp , slice_dim , slice_start , slice_end , slice_step ),
855+ )
856+ new_slice .meta ["val" ] = exir_ops .edge .aten .slice_copy .Tensor (
857+ inp .meta ["val" ], slice_dim , slice_start , slice_end , slice_step
858+ )
859+ new_args [i ] = new_slice
860+
861+ target = cast (EdgeOpOverload , op_node .target )
862+ new_op = graph .call_function (
863+ target ,
864+ args = tuple (new_args ),
865+ kwargs = op_node .kwargs ,
866+ )
867+ new_op .meta ["val" ] = target (
868+ * [
869+ a .meta ["val" ] if isinstance (a , torch .fx .Node ) else a
870+ for a in new_args
871+ ],
872+ ** {
873+ k : v .meta ["val" ] if isinstance (v , torch .fx .Node ) else v
874+ for k , v in op_node .kwargs .items ()
875+ },
876+ )
877+
878+ slice_node .replace_all_uses_with (new_op )
879+ graph .erase_node (slice_node )
880+ graph .erase_node (op_node )
881+ return True
882+
883+ def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
884+ parent = get_arg (node , "input" , torch .fx .Node )
812885 if len (parent .users ) != 1 :
813886 return False
814887 if not isinstance (parent .target , EdgeOpOverload ):
0 commit comments