Skip to content

Commit 8464b47

Browse files
authored
Add sum.dim_IntList to RemovePermutesAroundElementwiseOps
Differential Revision: D103272273 Pull Request resolved: #19243
1 parent 50c2a4e commit 8464b47

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

backends/transforms/remove_permutes_around_elementwise_ops.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Subgraph:
5050
# Ops that require special handling.
5151
exir_ops.edge.aten.cat.default,
5252
exir_ops.edge.aten.mean.dim,
53+
exir_ops.edge.aten.sum.dim_IntList,
5354
exir_ops.edge.aten.slice_copy.Tensor,
5455
}
5556

@@ -161,7 +162,10 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
161162
def is_node_permutable(self, node: torch.fx.Node) -> bool:
162163
if node.target not in self.permutable_ops:
163164
return False
164-
if node.target == exir_ops.edge.aten.mean.dim:
165+
if node.target in (
166+
exir_ops.edge.aten.mean.dim,
167+
exir_ops.edge.aten.sum.dim_IntList,
168+
):
165169
# keepdim should be True.
166170
if len(node.args) >= 3:
167171
if not node.args[2]:
@@ -236,7 +240,10 @@ def permute_subgraph(self, subgraph: Subgraph) -> None:
236240
for node in subgraph.nodes:
237241
if node.target == exir_ops.edge.aten.cat.default:
238242
self.update_cat(node, subgraph.start_permute)
239-
elif node.target == exir_ops.edge.aten.mean.dim:
243+
elif node.target in (
244+
exir_ops.edge.aten.mean.dim,
245+
exir_ops.edge.aten.sum.dim_IntList,
246+
):
240247
self.update_mean_dim(node, subgraph.start_permute)
241248
elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
242249
self.update_slice_copy(node, subgraph.start_permute)

0 commit comments

Comments
 (0)