Skip to content

Commit a4593f1

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Add sum.dim_IntList to RemovePermutesAroundElementwiseOps
Summary: `DecomposeMeanDimPass` decomposes `mean.dim` into `sum.dim_IntList + mul.Tensor`. While `mean.dim` and `mul.Tensor` are both in the `permutable_ops` set of `RemovePermutesAroundElementwiseOps`, `sum.dim_IntList` is not — so the pass cannot traverse through the decomposed chain to find the matching exit permute. This adds `aten.sum.dim_IntList` to `permutable_ops` with the same keepdim validation and dimension-adjustment logic already used for `mean.dim`. This is safe because: - `sum.dim_IntList` has identical dimension semantics to `mean.dim` - `DecomposeMeanDimPass` always calls sum with `keepdim=True` - `mul.Tensor` (the other half of the decomposition) is already handled Reviewed By: 3l1 Differential Revision: D103272273
1 parent 94d2881 commit a4593f1

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

backends/transforms/remove_permutes_around_elementwise_ops.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class Subgraph:
5050
# Ops that require special handling.
5151
exir_ops.edge.aten.cat.default,
5252
exir_ops.edge.aten.mean.dim,
53+
exir_ops.edge.aten.sum.dim_IntList,
5354
exir_ops.edge.aten.slice_copy.Tensor,
5455
}
5556

@@ -161,7 +162,10 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
161162
def is_node_permutable(self, node: torch.fx.Node) -> bool:
162163
if node.target not in self.permutable_ops:
163164
return False
164-
if node.target == exir_ops.edge.aten.mean.dim:
165+
if node.target in (
166+
exir_ops.edge.aten.mean.dim,
167+
exir_ops.edge.aten.sum.dim_IntList,
168+
):
165169
# keepdim should be True.
166170
if len(node.args) >= 3:
167171
if not node.args[2]:
@@ -236,7 +240,10 @@ def permute_subgraph(self, subgraph: Subgraph) -> None:
236240
for node in subgraph.nodes:
237241
if node.target == exir_ops.edge.aten.cat.default:
238242
self.update_cat(node, subgraph.start_permute)
239-
elif node.target == exir_ops.edge.aten.mean.dim:
243+
elif node.target in (
244+
exir_ops.edge.aten.mean.dim,
245+
exir_ops.edge.aten.sum.dim_IntList,
246+
):
240247
self.update_mean_dim(node, subgraph.start_permute)
241248
elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
242249
self.update_slice_copy(node, subgraph.start_permute)

0 commit comments

Comments
 (0)