|
7 | 7 | import copy |
8 | 8 | import traceback |
9 | 9 | from abc import abstractmethod |
| 10 | +from collections.abc import Collection |
10 | 11 | from typing import Any, List, Optional, Set, Type |
11 | 12 |
|
12 | 13 | import torch |
13 | 14 | from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY |
14 | 15 | from executorch.backends.arm.tosa.mapping import TosaSpecialDtype |
15 | 16 | from executorch.exir.dialects._ops import ops as exir_ops |
16 | 17 | from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue |
17 | | -from torch.fx import GraphModule |
| 18 | +from torch.fx import GraphModule, Node |
18 | 19 | from torch.fx.passes.infra.pass_base import PassResult |
19 | 20 | from torch.utils import _pytree as pytree |
20 | 21 |
|
@@ -191,3 +192,99 @@ def call_scalar(self, value: int | float, meta: NodeMetadata | dict[str, Any]): |
191 | 192 | meta=meta, |
192 | 193 | updated=True, |
193 | 194 | ) |
| 195 | + |
| 196 | + def should_run_pass(self, graph_module: GraphModule) -> bool: |
| 197 | + """Return whether this pass should run on the graph module. |
| 198 | +
|
| 199 | + Subclasses can override this to cheaply skip the pass before |
| 200 | + ``call()`` starts the normal ``ExportPass`` retracing path. |
| 201 | +
|
| 202 | + Args: |
| 203 | + graph_module (GraphModule): The graph module to inspect. |
| 204 | +
|
| 205 | + Returns: |
| 206 | + bool: True when the pass should run. |
| 207 | +
|
| 208 | + """ |
| 209 | + return True |
| 210 | + |
| 211 | + def __call__(self, graph_module: GraphModule) -> PassResult | None: |
| 212 | + self.requires(graph_module) |
| 213 | + if not self.should_run_pass(graph_module): |
| 214 | + self.ensures(graph_module) |
| 215 | + return PassResult(graph_module, False) |
| 216 | + res = self.call(graph_module) |
| 217 | + self.ensures(graph_module) |
| 218 | + return res |
| 219 | + |
| 220 | + |
| 221 | +class ArmOpTargetedPass(ArmPass): |
| 222 | + """Base class for passes that only transform selected operators. |
| 223 | +
|
| 224 | + Subclasses set ``target_ops`` to the call_function targets they can |
| 225 | + transform. If the current graph and nested control-flow subgraphs do not |
| 226 | + contain any target, the pass returns immediately without paying the default |
| 227 | + ExportPass retracing cost. |
| 228 | +
|
| 229 | + Set ``check_allowed_to_transform`` to ``True`` when the target pre-scan |
| 230 | + should also apply ``allowed_to_transform()`` to matching target nodes. This |
| 231 | + is useful for TFA passes whose ``call_operator()`` leaves disallowed target |
| 232 | + nodes unchanged. If all matching targets are disallowed, the pass can |
| 233 | + return before entering the normal ``ExportPass`` path. |
| 234 | +
|
| 235 | + """ |
| 236 | + |
| 237 | + target_ops: Collection[Any] = () |
| 238 | + check_allowed_to_transform = False |
| 239 | + |
| 240 | + def has_target_node(self, graph_module: GraphModule) -> bool: |
| 241 | + """Return whether the graph module tree contains a target node. |
| 242 | +
|
| 243 | + Args: |
| 244 | + graph_module (GraphModule): The graph module tree to inspect. |
| 245 | +
|
| 246 | + Returns: |
| 247 | + bool: True if a matching call_function node is present. |
| 248 | +
|
| 249 | + """ |
| 250 | + visited_graph_modules = set() |
| 251 | + |
| 252 | + def target_node_can_trigger_pass(node: Node) -> bool: |
| 253 | + if not self.check_allowed_to_transform: |
| 254 | + return True |
| 255 | + if self.allowed_to_transform(node.meta): |
| 256 | + return True |
| 257 | + return False |
| 258 | + |
| 259 | + def graph_has_target(module: GraphModule) -> bool: |
| 260 | + if id(module) in visited_graph_modules: |
| 261 | + return False |
| 262 | + visited_graph_modules.add(id(module)) |
| 263 | + |
| 264 | + for target in self.target_ops: |
| 265 | + for node in module.graph.find_nodes( |
| 266 | + op="call_function", |
| 267 | + target=target, |
| 268 | + sort=False, |
| 269 | + ): |
| 270 | + if target_node_can_trigger_pass(node): |
| 271 | + return True |
| 272 | + |
| 273 | + return any( |
| 274 | + isinstance(child, GraphModule) and graph_has_target(child) |
| 275 | + for child in module.children() |
| 276 | + ) |
| 277 | + |
| 278 | + return graph_has_target(graph_module) |
| 279 | + |
| 280 | + def should_run_pass(self, graph_module: GraphModule) -> bool: |
| 281 | + """Return whether this pass has a target node to transform. |
| 282 | +
|
| 283 | + Args: |
| 284 | + graph_module (GraphModule): The graph module tree to inspect. |
| 285 | +
|
| 286 | + Returns: |
| 287 | + bool: True when a matching target node is present. |
| 288 | +
|
| 289 | + """ |
| 290 | + return self.has_target_node(graph_module) |
0 commit comments