12 changes: 12 additions & 0 deletions exir/passes/BUCK
@@ -418,6 +418,18 @@ fbcode_target(_kind = runtime.python_library,
],
)

fbcode_target(_kind = runtime.python_library,
name = "cse_pass",
srcs = [
"cse_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects/edge:lib",
],
)

fbcode_target(_kind = runtime.python_library,
name = "convert_constant_dim_order_pass",
srcs = [
201 changes: 201 additions & 0 deletions exir/passes/cse_pass.py
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.node import Node


class CSEPass(ExportPass):
"""
Common Subexpression Elimination using structural hashing (value numbering).

Deduplicates operations with identical computation structure, replacing
redundant computations with references to previously computed results.

Uses recursive structural keys: two nodes are equivalent if they have the
same op and their inputs are structurally equivalent. This naturally handles
chains like item(select(select(x, 0, a), 0, b)) without special cases.

Safety is determined automatically via op schema introspection:
- For OpOverload targets (aten ops): checks _schema for mutating arguments
- For operator.* targets (SymInt arithmetic): always safe
- A small denylist covers non-deterministic ops (rand, dropout, etc.)
"""

# Ops that are pure (no mutation per schema) but non-deterministic:
# same inputs can produce different outputs.
UNSAFE_OPS = frozenset(
[
"aten::rand",
"aten::rand_like",
"aten::randn",
"aten::randn_like",
"aten::randint",
"aten::randint_like",
"aten::randperm",
"aten::bernoulli",
"aten::dropout",
"aten::native_dropout",
"aten::multinomial",
"aten::normal",
"aten::uniform",
]
)

def _is_safe_target(self, target) -> bool:
"""
Determine if an op target is safe for CSE.

Uses schema introspection for OpOverload targets: if no argument
has alias_info with is_write=True, the op doesn't mutate and is
safe (unless it's non-deterministic).

Python operator.* targets are always safe (pure scalar arithmetic).
"""
# Python operator module functions are always pure and deterministic
if getattr(target, "__module__", None) == "_operator":
return True

# EdgeOpOverload targets (edge dialect ops)
if isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)):
schema_name = target._schema.name

# Only trust schema introspection for aten:: ops.
# Custom op schemas (mlx::, torchao::, etc.) may not accurately
# annotate mutation or side effects — default to unsafe.
if not schema_name.startswith("aten::"):
return False

# Check denylist for non-deterministic/side-effecting ops
if schema_name in self.UNSAFE_OPS:
return False

# Check schema for mutating arguments
for arg in target._schema.arguments:
if arg.alias_info is not None and arg.alias_info.is_write:
return False

return True

return False

def call(self, graph_module: GraphModule) -> PassResult:
graph = graph_module.graph

# Discover graph output nodes — includes buffer mutation outputs
# (e.g. index_copy for KV cache). These must never be deduplicated
# because the graph signature references them by name for writeback.
output_node = next(n for n in graph.nodes if n.op == "output")
self._output_nodes: set[Node] = set()
for arg in output_node.args[0]:
if isinstance(arg, Node):
self._output_nodes.add(arg)

self._vn_cache: dict[Node, int] = {} # Node → value number
self._safe_cache: dict[Any, bool] = {} # Cache for _is_safe_target
self._sig_to_vn: dict[Any, int] = {} # flat signature → value number
self._vn_to_node: dict[int, Node] = {} # value number → canonical node
self._next_vn = 0
modified = False

for node in list(graph.nodes):
vn = self._value_number(node)

if vn in self._vn_to_node:
canonical = self._vn_to_node[vn]
if canonical is not node:
node.replace_all_uses_with(canonical)
graph.erase_node(node)
modified = True
else:
self._vn_to_node[vn] = node

if modified:
graph.eliminate_dead_code()
graph.lint()

return PassResult(graph_module, modified)

def _is_safe(self, target) -> bool:
"""Cached version of _is_safe_target."""
tid = id(target)
if tid not in self._safe_cache:
self._safe_cache[tid] = self._is_safe_target(target)
return self._safe_cache[tid]

def _new_vn(self) -> int:
"""Allocate a fresh unique value number."""
vn = self._next_vn
self._next_vn += 1
return vn

def _value_number(self, node: Node) -> int:
"""
Assign an integer value number to a node (global value numbering).

Two nodes with the same value number are structurally equivalent
and can be deduplicated. All signature tuples are flat (contain
only ints and scalars), so hashing is O(n_args) not O(graph_depth).
"""
if node in self._vn_cache:
return self._vn_cache[node]

if node.op != "call_function":
vn = self._new_vn()
elif node in self._output_nodes:
# Graph output node (includes buffer mutations like index_copy).
# Must keep unique — graph signature references this node by name.
vn = self._new_vn()
elif not self._is_safe(node.target):
vn = self._new_vn()
else:
try:
args_sig = tuple(self._make_hashable(a) for a in node.args)
kwargs_sig = tuple(
sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items())
)
sig = (node.target, args_sig, kwargs_sig)

if sig in self._sig_to_vn:
vn = self._sig_to_vn[sig]
else:
vn = self._new_vn()
self._sig_to_vn[sig] = vn
except TypeError:
vn = self._new_vn()

self._vn_cache[node] = vn
return vn

def _make_hashable(self, obj) -> Any:
"""Convert args/kwargs to a hashable form.

For Node args, returns the integer value number — keeping
all signature tuples flat and O(1) to hash.
"""
if isinstance(obj, Node):
return self._value_number(obj)
elif isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, (list, tuple)):
return tuple(self._make_hashable(x) for x in obj)
elif isinstance(obj, dict):
return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items()))
elif isinstance(obj, torch.dtype):
return obj
elif isinstance(obj, torch.device):
return str(obj)
elif isinstance(obj, torch.layout):
return str(obj)
elif isinstance(obj, torch.memory_format):
return str(obj)
else:
raise TypeError(f"Cannot make {type(obj)} hashable")
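
A minimal usage sketch (illustrative only, not part of this diff): the pass can be applied directly to an edge-dialect graph module, mirroring the tests added below; the toy module and input shape are placeholders.

import torch
from executorch.exir import to_edge
from executorch.exir.passes.cse_pass import CSEPass

class Model(torch.nn.Module):
    def forward(self, x):
        # Two structurally identical subexpressions; CSE keeps only one.
        return torch.neg(torch.abs(x)) + torch.neg(torch.abs(x))

ep = torch.export.export(Model(), (torch.randn(4, 4),), strict=False)
gm = to_edge(ep).exported_program().graph_module
result = CSEPass()(gm)  # PassResult exposing .graph_module and .modified
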
1 change: 1 addition & 0 deletions exir/tests/TARGETS
@@ -224,6 +224,7 @@ python_unittest(
"//executorch/exir/dialects/edge:lib",
"//executorch/exir/emit:lib",
"//executorch/exir/passes:constant_prop_pass",
"//executorch/exir/passes:cse_pass",
"//executorch/exir/passes:debug_handle_generator_pass",
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
91 changes: 91 additions & 0 deletions exir/tests/test_passes.py
@@ -53,6 +53,7 @@
ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.cse_pass import CSEPass
from executorch.exir.passes.debug_handle_generator_pass import (
DebugHandleGeneratorPass,
generate_missing_debug_handles,
@@ -2499,3 +2500,93 @@ def test_convert_constant_dim_order_to_contiguous(self):
modified_const.is_contiguous(),
f"Constant should be contiguous after pass, got strides {modified_const.stride()}",
)


class TestCSEPass(unittest.TestCase):
"""Tests for Common Subexpression Elimination pass."""

@staticmethod
def _to_edge_gm(module, example_inputs):
ep = torch.export.export(module, example_inputs, strict=False)
edge = to_edge(ep)
return edge.exported_program().graph_module

@staticmethod
def _count_ops(gm, target):
return sum(
1 for n in gm.graph.nodes if n.op == "call_function" and n.target == target
)

def test_duplicate_unary_ops_deduplicated(self):
"""Two identical neg(x) ops should be merged into one."""

class M(torch.nn.Module):
def forward(self, x):
a = torch.neg(x)
b = torch.neg(x)
return a + b

gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
target = exir_ops.edge.aten.neg.default
before = self._count_ops(gm, target)

if before < 2:
self.skipTest("Export already deduplicated neg ops")

result = CSEPass()(gm)

self.assertTrue(result.modified)
self.assertEqual(self._count_ops(result.graph_module, target), 1)

def test_different_ops_not_merged(self):
"""neg(x) and abs(x) should not be merged."""

class M(torch.nn.Module):
def forward(self, x):
return torch.neg(x) + torch.abs(x)

gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
result = CSEPass()(gm)
self.assertFalse(result.modified)

def test_same_op_different_inputs_not_merged(self):
"""neg(x) and neg(y) should not be merged."""

class M(torch.nn.Module):
def forward(self, x, y):
return torch.neg(x) + torch.neg(y)

gm = self._to_edge_gm(M(), (torch.randn(4, 4), torch.randn(4, 4)))
result = CSEPass()(gm)
self.assertFalse(result.modified)

def test_noop_when_no_duplicates(self):
class M(torch.nn.Module):
def forward(self, x):
return x + 1

gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
result = CSEPass()(gm)
self.assertFalse(result.modified)

def test_duplicate_chains_deduplicated(self):
"""Duplicate multi-op chains should be merged via structural hashing."""

class M(torch.nn.Module):
def forward(self, x):
a = torch.neg(torch.abs(x))
b = torch.neg(torch.abs(x))
return a + b

gm = self._to_edge_gm(M(), (torch.randn(4, 4),))
neg_target = exir_ops.edge.aten.neg.default
abs_target = exir_ops.edge.aten.abs.default

if self._count_ops(gm, neg_target) < 2:
self.skipTest("Export already deduplicated chains")

result = CSEPass()(gm)

self.assertTrue(result.modified)
self.assertEqual(self._count_ops(result.graph_module, neg_target), 1)
self.assertEqual(self._count_ops(result.graph_module, abs_target), 1)