|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +import torch |
| 10 | +from executorch.exir.dialects.edge._ops import EdgeOpOverload |
| 11 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 12 | +from torch.fx import GraphModule |
| 13 | +from torch.fx.node import Node |
| 14 | + |
| 15 | + |
| 16 | +class CSEPass(ExportPass): |
| 17 | + """ |
| 18 | + Common Subexpression Elimination using structural hashing (value numbering). |
| 19 | +
|
| 20 | + Deduplicates operations with identical computation structure, replacing |
| 21 | + redundant computations with references to previously computed results. |
| 22 | +
|
| 23 | + Uses recursive structural keys: two nodes are equivalent if they have the |
| 24 | + same op and their inputs are structurally equivalent. This naturally handles |
| 25 | + chains like item(select(select(x, 0, a), 0, b)) without special cases. |
| 26 | +
|
| 27 | + Safety is determined automatically via op schema introspection: |
| 28 | + - For OpOverload targets (aten ops): checks _schema for mutating arguments |
| 29 | + - For operator.* targets (SymInt arithmetic): always safe |
| 30 | + - A small denylist covers non-deterministic ops (rand, dropout, etc.) |
| 31 | + """ |
| 32 | + |
| 33 | + # Ops that are pure (no mutation per schema) but non-deterministic: |
| 34 | + # same inputs can produce different outputs. |
| 35 | + UNSAFE_OPS = frozenset( |
| 36 | + [ |
| 37 | + "aten::rand", |
| 38 | + "aten::rand_like", |
| 39 | + "aten::randn", |
| 40 | + "aten::randn_like", |
| 41 | + "aten::randint", |
| 42 | + "aten::randint_like", |
| 43 | + "aten::randperm", |
| 44 | + "aten::bernoulli", |
| 45 | + "aten::dropout", |
| 46 | + "aten::native_dropout", |
| 47 | + "aten::multinomial", |
| 48 | + "aten::normal", |
| 49 | + "aten::uniform", |
| 50 | + ] |
| 51 | + ) |
| 52 | + |
| 53 | + def _is_safe_target(self, target) -> bool: |
| 54 | + """ |
| 55 | + Determine if an op target is safe for CSE. |
| 56 | +
|
| 57 | + Uses schema introspection for OpOverload targets: if no argument |
| 58 | + has alias_info with is_write=True, the op doesn't mutate and is |
| 59 | + safe (unless it's non-deterministic). |
| 60 | +
|
| 61 | + Python operator.* targets are always safe (pure scalar arithmetic). |
| 62 | + """ |
| 63 | + # Python operator module functions are always pure and deterministic |
| 64 | + if getattr(target, "__module__", None) == "_operator": |
| 65 | + return True |
| 66 | + |
| 67 | + # EdgeOpOverload targets (edge dialect ops) |
| 68 | + if isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)): |
| 69 | + schema_name = target._schema.name |
| 70 | + |
| 71 | + # Only trust schema introspection for aten:: ops. |
| 72 | + # Custom op schemas (mlx::, torchao::, etc.) may not accurately |
| 73 | + # annotate mutation or side effects — default to unsafe. |
| 74 | + if not schema_name.startswith("aten::"): |
| 75 | + return False |
| 76 | + |
| 77 | + # Check denylist for non-deterministic/side-effecting ops |
| 78 | + if schema_name in self.UNSAFE_OPS: |
| 79 | + return False |
| 80 | + |
| 81 | + # Check schema for mutating arguments |
| 82 | + for arg in target._schema.arguments: |
| 83 | + if arg.alias_info is not None and arg.alias_info.is_write: |
| 84 | + return False |
| 85 | + |
| 86 | + return True |
| 87 | + |
| 88 | + return False |
| 89 | + |
| 90 | + def call(self, graph_module: GraphModule) -> PassResult: |
| 91 | + graph = graph_module.graph |
| 92 | + |
| 93 | + # Discover graph output nodes — includes buffer mutation outputs |
| 94 | + # (e.g. index_copy for KV cache). These must never be deduplicated |
| 95 | + # because the graph signature references them by name for writeback. |
| 96 | + output_node = next(n for n in graph.nodes if n.op == "output") |
| 97 | + self._output_nodes: set[Node] = set() |
| 98 | + for arg in output_node.args[0]: |
| 99 | + if isinstance(arg, Node): |
| 100 | + self._output_nodes.add(arg) |
| 101 | + |
| 102 | + self._vn_cache: dict[Node, int] = {} # Node → value number |
| 103 | + self._safe_cache: dict[Any, bool] = {} # Cache for _is_safe_target |
| 104 | + self._sig_to_vn: dict[Any, int] = {} # flat signature → value number |
| 105 | + self._vn_to_node: dict[int, Node] = {} # value number → canonical node |
| 106 | + self._next_vn = 0 |
| 107 | + modified = False |
| 108 | + |
| 109 | + for node in list(graph.nodes): |
| 110 | + vn = self._value_number(node) |
| 111 | + |
| 112 | + if vn in self._vn_to_node: |
| 113 | + canonical = self._vn_to_node[vn] |
| 114 | + if canonical is not node: |
| 115 | + node.replace_all_uses_with(canonical) |
| 116 | + graph.erase_node(node) |
| 117 | + modified = True |
| 118 | + else: |
| 119 | + self._vn_to_node[vn] = node |
| 120 | + |
| 121 | + if modified: |
| 122 | + graph.eliminate_dead_code() |
| 123 | + graph.lint() |
| 124 | + |
| 125 | + return PassResult(graph_module, modified) |
| 126 | + |
| 127 | + def _is_safe(self, target) -> bool: |
| 128 | + """Cached version of _is_safe_target.""" |
| 129 | + tid = id(target) |
| 130 | + if tid not in self._safe_cache: |
| 131 | + self._safe_cache[tid] = self._is_safe_target(target) |
| 132 | + return self._safe_cache[tid] |
| 133 | + |
| 134 | + def _new_vn(self) -> int: |
| 135 | + """Allocate a fresh unique value number.""" |
| 136 | + vn = self._next_vn |
| 137 | + self._next_vn += 1 |
| 138 | + return vn |
| 139 | + |
| 140 | + def _value_number(self, node: Node) -> int: |
| 141 | + """ |
| 142 | + Assign an integer value number to a node (global value numbering). |
| 143 | +
|
| 144 | + Two nodes with the same value number are structurally equivalent |
| 145 | + and can be deduplicated. All signature tuples are flat (contain |
| 146 | + only ints and scalars), so hashing is O(n_args) not O(graph_depth). |
| 147 | + """ |
| 148 | + if node in self._vn_cache: |
| 149 | + return self._vn_cache[node] |
| 150 | + |
| 151 | + if node.op != "call_function": |
| 152 | + vn = self._new_vn() |
| 153 | + elif node in self._output_nodes: |
| 154 | + # Graph output node (includes buffer mutations like index_copy). |
| 155 | + # Must keep unique — graph signature references this node by name. |
| 156 | + vn = self._new_vn() |
| 157 | + elif not self._is_safe(node.target): |
| 158 | + vn = self._new_vn() |
| 159 | + else: |
| 160 | + try: |
| 161 | + args_sig = tuple(self._make_hashable(a) for a in node.args) |
| 162 | + kwargs_sig = tuple( |
| 163 | + sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items()) |
| 164 | + ) |
| 165 | + sig = (node.target, args_sig, kwargs_sig) |
| 166 | + |
| 167 | + if sig in self._sig_to_vn: |
| 168 | + vn = self._sig_to_vn[sig] |
| 169 | + else: |
| 170 | + vn = self._new_vn() |
| 171 | + self._sig_to_vn[sig] = vn |
| 172 | + except TypeError: |
| 173 | + vn = self._new_vn() |
| 174 | + |
| 175 | + self._vn_cache[node] = vn |
| 176 | + return vn |
| 177 | + |
| 178 | + def _make_hashable(self, obj) -> Any: |
| 179 | + """Convert args/kwargs to a hashable form. |
| 180 | +
|
| 181 | + For Node args, returns the integer value number — keeping |
| 182 | + all signature tuples flat and O(1) to hash. |
| 183 | + """ |
| 184 | + if isinstance(obj, Node): |
| 185 | + return self._value_number(obj) |
| 186 | + elif isinstance(obj, (int, float, str, bool, type(None))): |
| 187 | + return obj |
| 188 | + elif isinstance(obj, (list, tuple)): |
| 189 | + return tuple(self._make_hashable(x) for x in obj) |
| 190 | + elif isinstance(obj, dict): |
| 191 | + return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items())) |
| 192 | + elif isinstance(obj, torch.dtype): |
| 193 | + return obj |
| 194 | + elif isinstance(obj, torch.device): |
| 195 | + return str(obj) |
| 196 | + elif isinstance(obj, torch.layout): |
| 197 | + return str(obj) |
| 198 | + elif isinstance(obj, torch.memory_format): |
| 199 | + return str(obj) |
| 200 | + else: |
| 201 | + raise TypeError(f"Cannot make {type(obj)} hashable") |
0 commit comments