Skip to content

Commit a279b72

Browse files
authored
Use torch.Tag.pointwise to auto-discover permutable elementwise ops (#19457)
Differential Revision: D104612850 Pull Request resolved: #19457
1 parent c502916 commit a279b72

1 file changed

Lines changed: 35 additions & 22 deletions

File tree

backends/transforms/remove_permutes_around_elementwise_ops.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,15 @@ class Subgraph:
4040
default_factory=set
4141
)
4242

43+
# Ops explicitly listed as permutable. This includes non-pointwise ops
44+
# that need special dimension-argument handling (cat, mean, sum, slice)
45+
# and quantize/dequantize ops not tagged as pointwise in ATen.
46+
# In addition to this set, any op tagged with torch.Tag.pointwise is
47+
# automatically considered permutable (see is_node_permutable).
4348
permutable_ops: set[EdgeOpOverload] = {
44-
exir_ops.edge.aten.add.Tensor,
45-
exir_ops.edge.aten.mul.Tensor,
46-
exir_ops.edge.aten.hardtanh.default,
47-
exir_ops.edge.aten.clamp.default,
4849
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
4950
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
50-
# Ops that require special handling.
51+
# Ops that require special handling of dimension arguments.
5152
exir_ops.edge.aten.cat.default,
5253
exir_ops.edge.aten.mean.dim,
5354
exir_ops.edge.aten.sum.dim_IntList,
@@ -67,7 +68,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6768
end_permute = [start_permute.index(i) for i in range(len(start_permute))]
6869

6970
for user in node.users:
70-
if user.target not in self.permutable_ops:
71+
if user.target not in self.permutable_ops and not self._is_pointwise(
72+
user.target
73+
):
7174
continue
7275
# Create a separate subgraph for each user since there may be cases
7376
# where only a portion of the users are permutable.
@@ -159,24 +162,34 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
159162
return len(val.shape)
160163
return None
161164

165+
@staticmethod
166+
def _is_pointwise(target) -> bool:
167+
"""Check if a target op is tagged as pointwise in ATen."""
168+
op = getattr(target, "_op", None)
169+
if op is not None and hasattr(op, "tags"):
170+
return torch.Tag.pointwise in op.tags
171+
return False
172+
162173
def is_node_permutable(self, node: torch.fx.Node) -> bool:
163-
if node.target not in self.permutable_ops:
164-
return False
165-
if node.target in (
166-
exir_ops.edge.aten.mean.dim,
167-
exir_ops.edge.aten.sum.dim_IntList,
168-
):
169-
# keepdim should be True.
170-
if len(node.args) >= 3:
171-
if not node.args[2]:
172-
return False
173-
elif "keepdim" in node.kwargs:
174-
if not node.kwargs["keepdim"]:
174+
if node.target in self.permutable_ops:
175+
# Special-case validation for dim-based ops.
176+
if node.target in (
177+
exir_ops.edge.aten.mean.dim,
178+
exir_ops.edge.aten.sum.dim_IntList,
179+
):
180+
# keepdim should be True.
181+
if len(node.args) >= 3:
182+
if not node.args[2]:
183+
return False
184+
elif "keepdim" in node.kwargs:
185+
if not node.kwargs["keepdim"]:
186+
return False
187+
else:
188+
# Default keepdim is False.
175189
return False
176-
else:
177-
# Default keepdim is False.
178-
return False
179-
return True
190+
return True
191+
# Accept any op tagged as pointwise in ATen (elementwise).
192+
return self._is_pointwise(node.target)
180193

181194
def permute_subgraph(self, subgraph: Subgraph) -> None:
182195
# Skip incoming permutes.

0 commit comments

Comments
 (0)