Skip to content

Commit 868d955

Browse files
committed
up
1 parent 2c73c84 commit 868d955

2 files changed

Lines changed: 2 additions & 319 deletions

File tree

backends/mlx/passes.py

Lines changed: 2 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from dataclasses import dataclass
12-
from typing import Any, List, Optional
12+
from typing import List, Optional
1313

1414
import torch
1515
from executorch.backends.mlx.pattern_utils import (
@@ -20,8 +20,8 @@
2020
walk_back,
2121
)
2222
from executorch.exir.dialects._ops import ops as exir_ops
23-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2423
from executorch.exir.pass_base import ExportPass, PassResult
24+
from executorch.exir.passes.cse_pass import CSEPass
2525
from torch.fx import GraphModule, Node
2626

2727

@@ -506,191 +506,3 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
506506
graph.lint()
507507

508508
return PassResult(graph_module, modified)
509-
510-
511-
class CSEPass(ExportPass):
    """
    Common Subexpression Elimination using structural hashing (value numbering).

    Deduplicates operations with identical computation structure, replacing
    redundant computations with references to previously computed results.

    Uses recursive structural keys: two nodes are equivalent if they have the
    same op and their inputs are structurally equivalent. This naturally handles
    chains like item(select(select(x, 0, a), 0, b)) without special cases.

    Safety is determined automatically via op schema introspection:
    - For OpOverload targets (aten ops): checks _schema for mutating arguments
    - For operator.* targets (SymInt arithmetic): always safe
    - A small denylist covers non-deterministic ops (rand, dropout, etc.)
    """

    # Ops that are pure (no mutation per schema) but non-deterministic:
    # same inputs can produce different outputs. These must never be
    # deduplicated even though their schemas show no mutation.
    UNSAFE_OPS = frozenset(
        [
            "aten::rand",
            "aten::rand_like",
            "aten::randn",
            "aten::randn_like",
            "aten::randint",
            "aten::randint_like",
            "aten::randperm",
            "aten::bernoulli",
            "aten::dropout",
            "aten::native_dropout",
            "aten::multinomial",
            "aten::normal",
            "aten::uniform",
        ]
    )

    def _is_safe_target(self, target) -> bool:
        """
        Determine if an op target is safe for CSE.

        Uses schema introspection for OpOverload targets: if no argument
        has alias_info with is_write=True, the op doesn't mutate and is
        safe (unless it's non-deterministic).

        Python operator.* targets are always safe (pure scalar arithmetic).

        Anything else (e.g. call_function targets that are plain Python
        callables) falls through to False: unknown means unsafe.
        """
        # Python operator module functions are always pure and deterministic
        if getattr(target, "__module__", None) == "_operator":
            return True

        # EdgeOpOverload targets (edge dialect ops)
        if isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)):
            schema_name = target._schema.name

            # Only trust schema introspection for aten:: ops.
            # Custom op schemas (mlx::, torchao::, etc.) may not accurately
            # annotate mutation or side effects — default to unsafe.
            if not schema_name.startswith("aten::"):
                return False

            # Check denylist for non-deterministic/side-effecting ops
            if schema_name in self.UNSAFE_OPS:
                return False

            # Check schema for mutating arguments: any argument whose
            # alias_info has is_write=True means the op mutates its input.
            for arg in target._schema.arguments:
                if arg.alias_info is not None and arg.alias_info.is_write:
                    return False

            return True

        return False

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Run CSE over *graph_module* and return a PassResult whose
        ``modified`` flag reports whether any node was deduplicated.

        Per-run state (vn caches, canonical-node map) is (re)initialized
        on self here, so a single CSEPass instance can be reused across
        graph modules.
        """
        graph = graph_module.graph

        # Discover graph output nodes — includes buffer mutation outputs
        # (e.g. index_copy for KV cache). These must never be deduplicated
        # because the graph signature references them by name for writeback.
        output_node = next(n for n in graph.nodes if n.op == "output")
        self._output_nodes: set[Node] = set()
        for arg in output_node.args[0]:
            if isinstance(arg, Node):
                self._output_nodes.add(arg)

        self._vn_cache: dict[Node, int] = {}  # Node → value number
        self._safe_cache: dict[Any, bool] = {}  # Cache for _is_safe_target
        self._sig_to_vn: dict[Any, int] = {}  # flat signature → value number
        self._vn_to_node: dict[int, Node] = {}  # value number → canonical node
        self._next_vn = 0
        modified = False

        # graph.nodes iterates in topological order, so the first node seen
        # for a given vn dominates any later duplicate — replacing uses of
        # the duplicate with it is therefore always legal. Snapshot with
        # list() because we erase nodes while iterating.
        for node in list(graph.nodes):
            vn = self._value_number(node)

            if vn in self._vn_to_node:
                canonical = self._vn_to_node[vn]
                if canonical is not node:
                    node.replace_all_uses_with(canonical)
                    graph.erase_node(node)
                    modified = True
            else:
                self._vn_to_node[vn] = node

        if modified:
            graph.eliminate_dead_code()
            graph.lint()

        return PassResult(graph_module, modified)

    def _is_safe(self, target) -> bool:
        """Cached version of _is_safe_target."""
        # Keyed by id() because targets may be unhashable.
        # NOTE(review): assumes op-overload target objects stay alive for the
        # duration of the pass (id reuse after GC would alias cache entries);
        # graph nodes hold references to their targets, so this holds here.
        tid = id(target)
        if tid not in self._safe_cache:
            self._safe_cache[tid] = self._is_safe_target(target)
        return self._safe_cache[tid]

    def _new_vn(self) -> int:
        """Allocate a fresh unique value number."""
        vn = self._next_vn
        self._next_vn += 1
        return vn

    def _value_number(self, node: Node) -> int:
        """
        Assign an integer value number to a node (global value numbering).

        Two nodes with the same value number are structurally equivalent
        and can be deduplicated. All signature tuples are flat (contain
        only ints and scalars), so hashing is O(n_args) not O(graph_depth).

        Nodes that must stay unique (non-call_function ops, graph outputs,
        unsafe targets, or nodes with unhashable args) each get a fresh
        value number, which makes them un-deduplicatable by construction.
        """
        if node in self._vn_cache:
            return self._vn_cache[node]

        if node.op != "call_function":
            vn = self._new_vn()
        elif node in self._output_nodes:
            # Graph output node (includes buffer mutations like index_copy).
            # Must keep unique — graph signature references this node by name.
            vn = self._new_vn()
        elif not self._is_safe(node.target):
            vn = self._new_vn()
        else:
            try:
                args_sig = tuple(self._make_hashable(a) for a in node.args)
                # Sort kwargs so keyword order doesn't affect the signature.
                kwargs_sig = tuple(
                    sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items())
                )
                sig = (node.target, args_sig, kwargs_sig)

                if sig in self._sig_to_vn:
                    vn = self._sig_to_vn[sig]
                else:
                    vn = self._new_vn()
                    self._sig_to_vn[sig] = vn
            except TypeError:
                # _make_hashable raised: some arg type we don't understand.
                # Fall back to a unique vn — never deduplicated, always safe.
                vn = self._new_vn()

        self._vn_cache[node] = vn
        return vn

    def _make_hashable(self, obj) -> Any:
        """Convert args/kwargs to a hashable form.

        For Node args, returns the integer value number — keeping
        all signature tuples flat and O(1) to hash.

        Raises TypeError for unrecognized types; _value_number catches
        this and falls back to a unique (non-dedupable) value number.
        """
        if isinstance(obj, Node):
            return self._value_number(obj)
        elif isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        elif isinstance(obj, (list, tuple)):
            return tuple(self._make_hashable(x) for x in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items()))
        elif isinstance(obj, torch.dtype):
            return obj
        elif isinstance(obj, torch.device):
            # devices/layouts/memory_formats are stringified: equal values
            # compare equal as strings and strings are always hashable.
            return str(obj)
        elif isinstance(obj, torch.layout):
            return str(obj)
        elif isinstance(obj, torch.memory_format):
            return str(obj)
        else:
            raise TypeError(f"Cannot make {type(obj)} hashable")

0 commit comments

Comments
 (0)