@@ -50,6 +50,7 @@ class Subgraph:
5050 # Ops that require special handling.
5151 exir_ops .edge .aten .cat .default ,
5252 exir_ops .edge .aten .mean .dim ,
53+ exir_ops .edge .aten .sum .dim_IntList ,
5354 exir_ops .edge .aten .slice_copy .Tensor ,
5455 }
5556
@@ -161,7 +162,10 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
161162 def is_node_permutable (self , node : torch .fx .Node ) -> bool :
162163 if node .target not in self .permutable_ops :
163164 return False
164- if node .target == exir_ops .edge .aten .mean .dim :
165+ if node .target in (
166+ exir_ops .edge .aten .mean .dim ,
167+ exir_ops .edge .aten .sum .dim_IntList ,
168+ ):
165169 # keepdim should be True.
166170 if len (node .args ) >= 3 :
167171 if not node .args [2 ]:
@@ -236,7 +240,10 @@ def permute_subgraph(self, subgraph: Subgraph) -> None:
236240 for node in subgraph .nodes :
237241 if node .target == exir_ops .edge .aten .cat .default :
238242 self .update_cat (node , subgraph .start_permute )
239- elif node .target == exir_ops .edge .aten .mean .dim :
243+ elif node .target in (
244+ exir_ops .edge .aten .mean .dim ,
245+ exir_ops .edge .aten .sum .dim_IntList ,
246+ ):
240247 self .update_mean_dim (node , subgraph .start_permute )
241248 elif node .target == exir_ops .edge .aten .slice_copy .Tensor :
242249 self .update_slice_copy (node , subgraph .start_permute )
0 commit comments