Skip to content

Commit 1f4ad07

Browse files
authored
Introduce CSE pass to ExecuTorch (#17752)
This introduces a CSE pass to ExecuTorch, which eliminates common subexpressions that occur in exported programs. This pass was first developed as part of the MLX delegate (#16718) to optimize transformers, but I'm introducing it to ExecuTorch more generally because I believe it could benefit many other backends.

Examples of common subexpressions that occur in transformers:

* Repeated mask constructions per layer (only needs to be done once)
* Repeated extraction of symints from 1d tensors for cache position (emits .item calls, which cause tensor materialization)

This pass eliminates these inefficiencies without having to rewrite the model.
1 parent b24a88e commit 1f4ad07

4 files changed

Lines changed: 305 additions & 0 deletions

File tree

exir/passes/BUCK

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,18 @@ fbcode_target(_kind = runtime.python_library,
418418
],
419419
)
420420

421+
fbcode_target(_kind = runtime.python_library,
422+
name = "cse_pass",
423+
srcs = [
424+
"cse_pass.py",
425+
],
426+
deps = [
427+
"//caffe2:torch",
428+
"//executorch/exir:pass_base",
429+
"//executorch/exir/dialects/edge:lib",
430+
],
431+
)
432+
421433
fbcode_target(_kind = runtime.python_library,
422434
name = "convert_constant_dim_order_pass",
423435
srcs = [

exir/passes/cse_pass.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any
8+
9+
import torch
10+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
from torch.fx import GraphModule
13+
from torch.fx.node import Node
14+
15+
16+
class CSEPass(ExportPass):
    """
    Common Subexpression Elimination driven by global value numbering.

    Structurally identical operations are collapsed onto a single canonical
    node: later duplicates have their uses redirected to the first occurrence
    and are erased. Two nodes are considered equivalent when they share the
    same target and their inputs are (recursively) equivalent, which handles
    whole duplicated chains such as item(select(select(x, 0, a), 0, b))
    without any pattern-specific logic.

    Whether an op may be deduplicated is decided from its schema:
    - OpOverload / EdgeOpOverload targets are safe when no argument's
      alias_info marks a write (i.e. the op does not mutate its inputs);
    - functions from the Python ``_operator`` module (SymInt arithmetic)
      are always safe;
    - a small denylist excludes pure-but-non-deterministic aten ops
      (rand, dropout, ...), which must never be merged.
    """

    # Ops whose schema shows no mutation but whose outputs are
    # non-deterministic: calling them twice with identical inputs
    # is not redundant, so they are never merged.
    UNSAFE_OPS = frozenset(
        [
            "aten::rand",
            "aten::rand_like",
            "aten::randn",
            "aten::randn_like",
            "aten::randint",
            "aten::randint_like",
            "aten::randperm",
            "aten::bernoulli",
            "aten::dropout",
            "aten::native_dropout",
            "aten::multinomial",
            "aten::normal",
            "aten::uniform",
        ]
    )

    def _is_safe_target(self, target) -> bool:
        """
        Decide whether ``target`` may participate in CSE.

        ``_operator`` functions are pure scalar arithmetic and always safe.
        For OpOverload / EdgeOpOverload targets the op schema is inspected:
        the op is safe when it is an ``aten::`` op, is not on the
        non-determinism denylist, and writes to none of its arguments.
        Everything else (custom namespaces, unknown target kinds) is
        conservatively treated as unsafe.
        """
        # Python operator-module functions: pure, deterministic scalar math.
        if getattr(target, "__module__", None) == "_operator":
            return True

        if not isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)):
            # Unknown target kind — refuse to dedupe.
            return False

        schema = target._schema

        # Schema introspection is only trusted for aten:: ops. Custom op
        # schemas (mlx::, torchao::, ...) may not accurately annotate
        # mutation or side effects, so they default to unsafe.
        if not schema.name.startswith("aten::"):
            return False

        # Non-deterministic ops are never mergeable even though their
        # schemas look pure.
        if schema.name in self.UNSAFE_OPS:
            return False

        # Any written-to argument means the op mutates state.
        return all(
            arg.alias_info is None or not arg.alias_info.is_write
            for arg in schema.arguments
        )

    def call(self, graph_module: GraphModule) -> PassResult:
        """Run CSE over ``graph_module`` and report whether anything changed."""
        graph = graph_module.graph

        # Pin every node referenced by the graph output — this includes
        # buffer-mutation writebacks (e.g. index_copy for KV caches). The
        # graph signature refers to these nodes by name, so replacing them
        # would break the writeback wiring.
        sink = next(n for n in graph.nodes if n.op == "output")
        self._output_nodes: set[Node] = {
            a for a in sink.args[0] if isinstance(a, Node)
        }

        self._vn_cache: dict[Node, int] = {}  # Node → value number
        self._safe_cache: dict[Any, bool] = {}  # memo for _is_safe_target
        self._sig_to_vn: dict[Any, int] = {}  # flat signature → value number
        self._vn_to_node: dict[int, Node] = {}  # value number → canonical node
        self._next_vn = 0
        modified = False

        # Nodes are visited in topological order, so every input of a node
        # already has a value number by the time the node itself is numbered.
        for node in list(graph.nodes):
            vn = self._value_number(node)
            canonical = self._vn_to_node.get(vn)
            if canonical is None:
                self._vn_to_node[vn] = node
            elif canonical is not node:
                node.replace_all_uses_with(canonical)
                graph.erase_node(node)
                modified = True

        if modified:
            graph.eliminate_dead_code()
            graph.lint()

        return PassResult(graph_module, modified)

    def _is_safe(self, target) -> bool:
        """Memoized wrapper around ``_is_safe_target``, keyed on target identity."""
        key = id(target)
        if key not in self._safe_cache:
            self._safe_cache[key] = self._is_safe_target(target)
        return self._safe_cache[key]

    def _new_vn(self) -> int:
        """Hand out the next unused value number."""
        fresh = self._next_vn
        self._next_vn += 1
        return fresh

    def _value_number(self, node: Node) -> int:
        """
        Compute (and cache) the value number of ``node``.

        Nodes sharing a value number are structurally equivalent and may be
        merged. Signatures are flat tuples of ints/scalars (Node inputs are
        replaced by their own value numbers), so hashing a signature costs
        O(n_args), independent of graph depth.
        """
        cached = self._vn_cache.get(node)
        if cached is not None:
            return cached

        if (
            # Only call_function nodes are candidates; placeholders,
            # get_attr, and the output node each stay unique.
            node.op != "call_function"
            # Graph outputs (incl. buffer mutations) are referenced by name
            # in the graph signature and must never be merged away.
            or node in self._output_nodes
            # Mutating / non-deterministic / unknown ops are never merged.
            or not self._is_safe(node.target)
        ):
            vn = self._new_vn()
        else:
            try:
                flat_args = tuple(self._make_hashable(a) for a in node.args)
                flat_kwargs = tuple(
                    sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items())
                )
                sig = (node.target, flat_args, flat_kwargs)
            except TypeError:
                # Unhashable argument — give up on this node, keep it unique.
                vn = self._new_vn()
            else:
                vn = self._sig_to_vn.get(sig)
                if vn is None:
                    vn = self._new_vn()
                    self._sig_to_vn[sig] = vn

        self._vn_cache[node] = vn
        return vn

    def _make_hashable(self, obj) -> Any:
        """Flatten an arg/kwarg value into a hashable signature component.

        Node inputs collapse to their integer value number, which keeps
        every signature tuple flat and O(1) to hash.
        """
        if isinstance(obj, Node):
            return self._value_number(obj)
        if isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        if isinstance(obj, (list, tuple)):
            return tuple(self._make_hashable(item) for item in obj)
        if isinstance(obj, dict):
            return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items()))
        if isinstance(obj, torch.dtype):
            return obj
        if isinstance(obj, (torch.device, torch.layout, torch.memory_format)):
            return str(obj)
        raise TypeError(f"Cannot make {type(obj)} hashable")

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ python_unittest(
224224
"//executorch/exir/dialects/edge:lib",
225225
"//executorch/exir/emit:lib",
226226
"//executorch/exir/passes:constant_prop_pass",
227+
"//executorch/exir/passes:cse_pass",
227228
"//executorch/exir/passes:debug_handle_generator_pass",
228229
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
229230
"//executorch/exir/passes:lib",

exir/tests/test_passes.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
ToOutVarPass,
5454
)
5555
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
56+
from executorch.exir.passes.cse_pass import CSEPass
5657
from executorch.exir.passes.debug_handle_generator_pass import (
5758
DebugHandleGeneratorPass,
5859
generate_missing_debug_handles,
@@ -2499,3 +2500,93 @@ def test_convert_constant_dim_order_to_contiguous(self):
24992500
modified_const.is_contiguous(),
25002501
f"Constant should be contiguous after pass, got strides {modified_const.stride()}",
25012502
)
2503+
2504+
2505+
class TestCSEPass(unittest.TestCase):
    """Unit tests for the Common Subexpression Elimination pass."""

    @staticmethod
    def _to_edge_gm(module, example_inputs):
        # Export + lower to the edge dialect and hand back the graph module.
        exported = torch.export.export(module, example_inputs, strict=False)
        return to_edge(exported).exported_program().graph_module

    @staticmethod
    def _count_ops(gm, target):
        # Number of call_function nodes in the graph hitting `target`.
        matches = [
            n
            for n in gm.graph.nodes
            if n.op == "call_function" and n.target == target
        ]
        return len(matches)

    def test_duplicate_unary_ops_deduplicated(self):
        """Two identical neg(x) ops should be merged into one."""

        class Model(torch.nn.Module):
            def forward(self, x):
                first = torch.neg(x)
                second = torch.neg(x)
                return first + second

        gm = self._to_edge_gm(Model(), (torch.randn(4, 4),))
        neg = exir_ops.edge.aten.neg.default

        # Guard against the exporter itself already merging the duplicates.
        if self._count_ops(gm, neg) < 2:
            self.skipTest("Export already deduplicated neg ops")

        result = CSEPass()(gm)

        self.assertTrue(result.modified)
        self.assertEqual(self._count_ops(result.graph_module, neg), 1)

    def test_different_ops_not_merged(self):
        """neg(x) and abs(x) should not be merged."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x) + torch.abs(x)

        gm = self._to_edge_gm(Model(), (torch.randn(4, 4),))
        self.assertFalse(CSEPass()(gm).modified)

    def test_same_op_different_inputs_not_merged(self):
        """neg(x) and neg(y) should not be merged."""

        class Model(torch.nn.Module):
            def forward(self, x, y):
                return torch.neg(x) + torch.neg(y)

        gm = self._to_edge_gm(Model(), (torch.randn(4, 4), torch.randn(4, 4)))
        self.assertFalse(CSEPass()(gm).modified)

    def test_noop_when_no_duplicates(self):
        """A graph with no redundancy must come back unmodified."""

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1

        gm = self._to_edge_gm(Model(), (torch.randn(4, 4),))
        self.assertFalse(CSEPass()(gm).modified)

    def test_duplicate_chains_deduplicated(self):
        """Duplicate multi-op chains should be merged via structural hashing."""

        class Model(torch.nn.Module):
            def forward(self, x):
                left = torch.neg(torch.abs(x))
                right = torch.neg(torch.abs(x))
                return left + right

        gm = self._to_edge_gm(Model(), (torch.randn(4, 4),))
        neg = exir_ops.edge.aten.neg.default
        absolute = exir_ops.edge.aten.abs.default

        # Guard against the exporter itself already merging the duplicates.
        if self._count_ops(gm, neg) < 2:
            self.skipTest("Export already deduplicated chains")

        result = CSEPass()(gm)

        self.assertTrue(result.modified)
        self.assertEqual(self._count_ops(result.graph_module, neg), 1)
        self.assertEqual(self._count_ops(result.graph_module, absolute), 1)

0 commit comments

Comments
 (0)