diff --git a/exir/passes/BUCK b/exir/passes/BUCK
index 9a387e9f96f..f2fff1f886a 100644
--- a/exir/passes/BUCK
+++ b/exir/passes/BUCK
@@ -418,6 +418,18 @@ fbcode_target(_kind = runtime.python_library,
     ],
 )
 
+fbcode_target(_kind = runtime.python_library,
+    name = "cse_pass",
+    srcs = [
+        "cse_pass.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects/edge:lib",
+    ],
+)
+
 fbcode_target(_kind = runtime.python_library,
     name = "convert_constant_dim_order_pass",
     srcs = [
diff --git a/exir/passes/cse_pass.py b/exir/passes/cse_pass.py
new file mode 100644
index 00000000000..1ada3a385b5
--- /dev/null
+++ b/exir/passes/cse_pass.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any
+
+import torch
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule
+from torch.fx.node import Node
+
+
+class CSEPass(ExportPass):
+    """
+    Common Subexpression Elimination using structural hashing (value numbering).
+
+    Deduplicates operations with identical computation structure, replacing
+    redundant computations with references to previously computed results.
+
+    Uses recursive structural keys: two nodes are equivalent if they have the
+    same op and their inputs are structurally equivalent. This naturally handles
+    chains like item(select(select(x, 0, a), 0, b)) without special cases.
+
+    Safety is determined automatically via op schema introspection:
+    - For OpOverload targets (aten ops): checks _schema for mutating arguments
+    - For operator.* targets (SymInt arithmetic): always safe
+    - A small denylist covers non-deterministic ops (rand, dropout, etc.)
+    """
+
+    # Ops that are pure (no mutation per schema) but non-deterministic:
+    # same inputs can produce different outputs.
+    UNSAFE_OPS = frozenset(
+        [
+            "aten::rand",
+            "aten::rand_like",
+            "aten::randn",
+            "aten::randn_like",
+            "aten::randint",
+            "aten::randint_like",
+            "aten::randperm",
+            "aten::bernoulli",
+            "aten::dropout",
+            "aten::native_dropout",
+            "aten::multinomial",
+            "aten::normal",
+            "aten::uniform",
+        ]
+    )
+
+    def _is_safe_target(self, target) -> bool:
+        """
+        Determine if an op target is safe for CSE.
+
+        Uses schema introspection for OpOverload targets: if no argument
+        has alias_info with is_write=True, the op doesn't mutate and is
+        safe (unless it's non-deterministic).
+
+        Python operator.* targets are always safe (pure scalar arithmetic).
+        """
+        # Python operator module functions are always pure and deterministic
+        if getattr(target, "__module__", None) == "_operator":
+            return True
+
+        # EdgeOpOverload targets (edge dialect ops)
+        if isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)):
+            schema_name = target._schema.name
+
+            # Only trust schema introspection for aten:: ops.
+            # Custom op schemas (mlx::, torchao::, etc.) may not accurately
+            # annotate mutation or side effects — default to unsafe.
+            if not schema_name.startswith("aten::"):
+                return False
+
+            # Check denylist for non-deterministic/side-effecting ops
+            if schema_name in self.UNSAFE_OPS:
+                return False
+
+            # Check schema for mutating arguments
+            for arg in target._schema.arguments:
+                if arg.alias_info is not None and arg.alias_info.is_write:
+                    return False
+
+            return True
+
+        return False
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        graph = graph_module.graph
+
+        # Discover graph output nodes — includes buffer mutation outputs
+        # (e.g. index_copy for KV cache). These must never be deduplicated
+        # because the graph signature references them by name for writeback.
+        output_node = next(n for n in graph.nodes if n.op == "output")
+        self._output_nodes: set[Node] = set()
+        for arg in output_node.args[0]:
+            if isinstance(arg, Node):
+                self._output_nodes.add(arg)
+
+        self._vn_cache: dict[Node, int] = {}  # Node → value number
+        self._safe_cache: dict[Any, bool] = {}  # Cache for _is_safe_target
+        self._sig_to_vn: dict[Any, int] = {}  # flat signature → value number
+        self._vn_to_node: dict[int, Node] = {}  # value number → canonical node
+        self._next_vn = 0
+        modified = False
+
+        for node in list(graph.nodes):
+            vn = self._value_number(node)
+
+            if vn in self._vn_to_node:
+                canonical = self._vn_to_node[vn]
+                if canonical is not node:
+                    node.replace_all_uses_with(canonical)
+                    graph.erase_node(node)
+                    modified = True
+            else:
+                self._vn_to_node[vn] = node
+
+        if modified:
+            graph.eliminate_dead_code()
+            # Regenerate the GraphModule's Python code so the generated
+            # forward() stays in sync with the mutated graph.
+            graph_module.recompile()
+            graph.lint()
+
+        return PassResult(graph_module, modified)
+
+    def _is_safe(self, target) -> bool:
+        """Cached version of _is_safe_target."""
+        tid = id(target)
+        if tid not in self._safe_cache:
+            self._safe_cache[tid] = self._is_safe_target(target)
+        return self._safe_cache[tid]
+
+    def _new_vn(self) -> int:
+        """Allocate a fresh unique value number."""
+        vn = self._next_vn
+        self._next_vn += 1
+        return vn
+
+    def _value_number(self, node: Node) -> int:
+        """
+        Assign an integer value number to a node (global value numbering).
+
+        Two nodes with the same value number are structurally equivalent
+        and can be deduplicated. All signature tuples are flat (contain
+        only ints and scalars), so hashing is O(n_args) not O(graph_depth).
+        """
+        if node in self._vn_cache:
+            return self._vn_cache[node]
+
+        if node.op != "call_function":
+            vn = self._new_vn()
+        elif node in self._output_nodes:
+            # Graph output node (includes buffer mutations like index_copy).
+            # Must keep unique — graph signature references this node by name.
+            vn = self._new_vn()
+        elif not self._is_safe(node.target):
+            vn = self._new_vn()
+        else:
+            try:
+                args_sig = tuple(self._make_hashable(a) for a in node.args)
+                kwargs_sig = tuple(
+                    sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items())
+                )
+                sig = (node.target, args_sig, kwargs_sig)
+
+                if sig in self._sig_to_vn:
+                    vn = self._sig_to_vn[sig]
+                else:
+                    vn = self._new_vn()
+                    self._sig_to_vn[sig] = vn
+            except TypeError:
+                vn = self._new_vn()
+
+        self._vn_cache[node] = vn
+        return vn
+
+    def _make_hashable(self, obj) -> Any:
+        """Convert args/kwargs to a hashable form.
+
+        For Node args, returns the integer value number — keeping
+        all signature tuples flat and O(1) to hash.
+        """
+        if isinstance(obj, Node):
+            return self._value_number(obj)
+        elif isinstance(obj, (int, float, str, bool, type(None))):
+            return obj
+        elif isinstance(obj, (list, tuple)):
+            return tuple(self._make_hashable(x) for x in obj)
+        elif isinstance(obj, dict):
+            return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items()))
+        elif isinstance(obj, torch.dtype):
+            return obj
+        elif isinstance(obj, torch.device):
+            return str(obj)
+        elif isinstance(obj, torch.layout):
+            return str(obj)
+        elif isinstance(obj, torch.memory_format):
+            return str(obj)
+        else:
+            raise TypeError(f"Cannot make {type(obj)} hashable")
diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS
index 3adce886208..01f579e0479 100644
--- a/exir/tests/TARGETS
+++ b/exir/tests/TARGETS
@@ -224,6 +224,7 @@ python_unittest(
         "//executorch/exir/dialects/edge:lib",
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:constant_prop_pass",
+        "//executorch/exir/passes:cse_pass",
         "//executorch/exir/passes:debug_handle_generator_pass",
         "//executorch/exir/passes:insert_write_back_for_buffers_pass",
         "//executorch/exir/passes:lib",
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 452f9694a8d..6d4fbd37107 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -53,6 +53,7 @@
     ToOutVarPass,
 )
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from executorch.exir.passes.cse_pass import CSEPass
 from executorch.exir.passes.debug_handle_generator_pass import (
     DebugHandleGeneratorPass,
     generate_missing_debug_handles,
@@ -2499,3 +2500,93 @@ def test_convert_constant_dim_order_to_contiguous(self):
             modified_const.is_contiguous(),
             f"Constant should be contiguous after pass, got strides {modified_const.stride()}",
         )
+
+
+class TestCSEPass(unittest.TestCase):
+    """Tests for Common Subexpression Elimination pass."""
+
+    @staticmethod
+    def _to_edge_gm(module, example_inputs):
+        ep = torch.export.export(module, example_inputs, strict=False)
+        edge = to_edge(ep)
+        return edge.exported_program().graph_module
+
+    @staticmethod
+    def _count_ops(gm, target):
+        return sum(
+            1 for n in gm.graph.nodes if n.op == "call_function" and n.target == target
+        )
+
+    def test_duplicate_unary_ops_deduplicated(self):
+        """Two identical neg(x) ops should be merged into one."""
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                a = torch.neg(x)
+                b = torch.neg(x)
+                return a + b
+
+        gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
+        target = exir_ops.edge.aten.neg.default
+        before = self._count_ops(gm, target)
+
+        if before < 2:
+            self.skipTest("Export already deduplicated neg ops")
+
+        result = CSEPass()(gm)
+
+        self.assertTrue(result.modified)
+        self.assertEqual(self._count_ops(result.graph_module, target), 1)
+
+    def test_different_ops_not_merged(self):
+        """neg(x) and abs(x) should not be merged."""
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return torch.neg(x) + torch.abs(x)
+
+        gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
+        result = CSEPass()(gm)
+        self.assertFalse(result.modified)
+
+    def test_same_op_different_inputs_not_merged(self):
+        """neg(x) and neg(y) should not be merged."""
+
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.neg(x) + torch.neg(y)
+
+        gm = self._to_edge_gm(M(), (torch.randn(4, 4), torch.randn(4, 4)))
+        result = CSEPass()(gm)
+        self.assertFalse(result.modified)
+
+    def test_noop_when_no_duplicates(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
+        result = CSEPass()(gm)
+        self.assertFalse(result.modified)
+
+    def test_duplicate_chains_deduplicated(self):
+        """Duplicate multi-op chains should be merged via structural hashing."""
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                a = torch.neg(torch.abs(x))
+                b = torch.neg(torch.abs(x))
+                return a + b
+
+        gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
+        neg_target = exir_ops.edge.aten.neg.default
+        abs_target = exir_ops.edge.aten.abs.default
+
+        if self._count_ops(gm, neg_target) < 2:
+            self.skipTest("Export already deduplicated chains")
+
+        result = CSEPass()(gm)
+
+        self.assertTrue(result.modified)
+        self.assertEqual(self._count_ops(result.graph_module, neg_target), 1)
+        self.assertEqual(self._count_ops(result.graph_module, abs_target), 1)