@@ -40,14 +40,15 @@ class Subgraph:
4040 default_factory = set
4141 )
4242
43+ # Ops explicitly listed as permutable. This includes non-pointwise ops
44+ # that need special dimension-argument handling (cat, mean, sum, slice)
45+ # and quantize/dequantize ops not tagged as pointwise in ATen.
46+ # In addition to this set, any op tagged with torch.Tag.pointwise is
47+ # automatically considered permutable (see is_node_permutable).
4348 permutable_ops : set [EdgeOpOverload ] = {
44- exir_ops .edge .aten .add .Tensor ,
45- exir_ops .edge .aten .mul .Tensor ,
46- exir_ops .edge .aten .hardtanh .default ,
47- exir_ops .edge .aten .clamp .default ,
4849 exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
4950 exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
50- # Ops that require special handling.
51+ # Ops that require special handling of dimension arguments .
5152 exir_ops .edge .aten .cat .default ,
5253 exir_ops .edge .aten .mean .dim ,
5354 exir_ops .edge .aten .sum .dim_IntList ,
@@ -67,7 +68,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
6768 end_permute = [start_permute .index (i ) for i in range (len (start_permute ))]
6869
6970 for user in node .users :
70- if user .target not in self .permutable_ops :
71+ if user .target not in self .permutable_ops and not self ._is_pointwise (
72+ user .target
73+ ):
7174 continue
7275 # Create a separate subgraph for each user since there may be cases
7376 # where only a portion of the users are permutable.
@@ -159,24 +162,34 @@ def _get_node_rank(self, node: torch.fx.Node) -> int | None:
159162 return len (val .shape )
160163 return None
161164
165+ @staticmethod
166+ def _is_pointwise (target ) -> bool :
167+ """Check if a target op is tagged as pointwise in ATen."""
168+ op = getattr (target , "_op" , None )
169+ if op is not None and hasattr (op , "tags" ):
170+ return torch .Tag .pointwise in op .tags
171+ return False
172+
162173 def is_node_permutable (self , node : torch .fx .Node ) -> bool :
163- if node .target not in self .permutable_ops :
164- return False
165- if node .target in (
166- exir_ops .edge .aten .mean .dim ,
167- exir_ops .edge .aten .sum .dim_IntList ,
168- ):
169- # keepdim should be True.
170- if len (node .args ) >= 3 :
171- if not node .args [2 ]:
172- return False
173- elif "keepdim" in node .kwargs :
174- if not node .kwargs ["keepdim" ]:
174+ if node .target in self .permutable_ops :
175+ # Special-case validation for dim-based ops.
176+ if node .target in (
177+ exir_ops .edge .aten .mean .dim ,
178+ exir_ops .edge .aten .sum .dim_IntList ,
179+ ):
180+ # keepdim should be True.
181+ if len (node .args ) >= 3 :
182+ if not node .args [2 ]:
183+ return False
184+ elif "keepdim" in node .kwargs :
185+ if not node .kwargs ["keepdim" ]:
186+ return False
187+ else :
188+ # Default keepdim is False.
175189 return False
176- else :
177- # Default keepdim is False.
178- return False
179- return True
190+ return True
191+ # Accept any op tagged as pointwise in ATen (elementwise).
192+ return self ._is_pointwise (node .target )
180193
181194 def permute_subgraph (self , subgraph : Subgraph ) -> None :
182195 # Skip incoming permutes.
0 commit comments