|
9 | 9 | """ |
10 | 10 |
|
11 | 11 | from dataclasses import dataclass |
12 | | -from typing import Any, List, Optional |
| 12 | +from typing import List, Optional |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.backends.mlx.pattern_utils import ( |
|
20 | 20 | walk_back, |
21 | 21 | ) |
22 | 22 | from executorch.exir.dialects._ops import ops as exir_ops |
23 | | -from executorch.exir.dialects.edge._ops import EdgeOpOverload |
24 | 23 | from executorch.exir.pass_base import ExportPass, PassResult |
| 24 | +from executorch.exir.passes.cse_pass import CSEPass |
25 | 25 | from torch.fx import GraphModule, Node |
26 | 26 |
|
27 | 27 |
|
@@ -506,191 +506,3 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 |
506 | 506 | graph.lint() |
507 | 507 |
|
508 | 508 | return PassResult(graph_module, modified) |
509 | | - |
510 | | - |
511 | | -class CSEPass(ExportPass): |
512 | | - """ |
513 | | - Common Subexpression Elimination using structural hashing (value numbering). |
514 | | -
|
515 | | - Deduplicates operations with identical computation structure, replacing |
516 | | - redundant computations with references to previously computed results. |
517 | | -
|
518 | | - Uses recursive structural keys: two nodes are equivalent if they have the |
519 | | - same op and their inputs are structurally equivalent. This naturally handles |
520 | | - chains like item(select(select(x, 0, a), 0, b)) without special cases. |
521 | | -
|
522 | | - Safety is determined automatically via op schema introspection: |
523 | | - - For OpOverload targets (aten ops): checks _schema for mutating arguments |
524 | | - - For operator.* targets (SymInt arithmetic): always safe |
525 | | - - A small denylist covers non-deterministic ops (rand, dropout, etc.) |
526 | | - """ |
527 | | - |
528 | | - # Ops that are pure (no mutation per schema) but non-deterministic: |
529 | | - # same inputs can produce different outputs. |
530 | | - UNSAFE_OPS = frozenset( |
531 | | - [ |
532 | | - "aten::rand", |
533 | | - "aten::rand_like", |
534 | | - "aten::randn", |
535 | | - "aten::randn_like", |
536 | | - "aten::randint", |
537 | | - "aten::randint_like", |
538 | | - "aten::randperm", |
539 | | - "aten::bernoulli", |
540 | | - "aten::dropout", |
541 | | - "aten::native_dropout", |
542 | | - "aten::multinomial", |
543 | | - "aten::normal", |
544 | | - "aten::uniform", |
545 | | - ] |
546 | | - ) |
547 | | - |
548 | | - def _is_safe_target(self, target) -> bool: |
549 | | - """ |
550 | | - Determine if an op target is safe for CSE. |
551 | | -
|
552 | | - Uses schema introspection for OpOverload targets: if no argument |
553 | | - has alias_info with is_write=True, the op doesn't mutate and is |
554 | | - safe (unless it's non-deterministic). |
555 | | -
|
556 | | - Python operator.* targets are always safe (pure scalar arithmetic). |
557 | | - """ |
558 | | - # Python operator module functions are always pure and deterministic |
559 | | - if getattr(target, "__module__", None) == "_operator": |
560 | | - return True |
561 | | - |
562 | | - # EdgeOpOverload targets (edge dialect ops) |
563 | | - if isinstance(target, (torch._ops.OpOverload, EdgeOpOverload)): |
564 | | - schema_name = target._schema.name |
565 | | - |
566 | | - # Only trust schema introspection for aten:: ops. |
567 | | - # Custom op schemas (mlx::, torchao::, etc.) may not accurately |
568 | | - # annotate mutation or side effects — default to unsafe. |
569 | | - if not schema_name.startswith("aten::"): |
570 | | - return False |
571 | | - |
572 | | - # Check denylist for non-deterministic/side-effecting ops |
573 | | - if schema_name in self.UNSAFE_OPS: |
574 | | - return False |
575 | | - |
576 | | - # Check schema for mutating arguments |
577 | | - for arg in target._schema.arguments: |
578 | | - if arg.alias_info is not None and arg.alias_info.is_write: |
579 | | - return False |
580 | | - |
581 | | - return True |
582 | | - |
583 | | - return False |
584 | | - |
585 | | - def call(self, graph_module: GraphModule) -> PassResult: |
586 | | - graph = graph_module.graph |
587 | | - |
588 | | - # Discover graph output nodes — includes buffer mutation outputs |
589 | | - # (e.g. index_copy for KV cache). These must never be deduplicated |
590 | | - # because the graph signature references them by name for writeback. |
591 | | - output_node = next(n for n in graph.nodes if n.op == "output") |
592 | | - self._output_nodes: set[Node] = set() |
593 | | - for arg in output_node.args[0]: |
594 | | - if isinstance(arg, Node): |
595 | | - self._output_nodes.add(arg) |
596 | | - |
597 | | - self._vn_cache: dict[Node, int] = {} # Node → value number |
598 | | - self._safe_cache: dict[Any, bool] = {} # Cache for _is_safe_target |
599 | | - self._sig_to_vn: dict[Any, int] = {} # flat signature → value number |
600 | | - self._vn_to_node: dict[int, Node] = {} # value number → canonical node |
601 | | - self._next_vn = 0 |
602 | | - modified = False |
603 | | - |
604 | | - for node in list(graph.nodes): |
605 | | - vn = self._value_number(node) |
606 | | - |
607 | | - if vn in self._vn_to_node: |
608 | | - canonical = self._vn_to_node[vn] |
609 | | - if canonical is not node: |
610 | | - node.replace_all_uses_with(canonical) |
611 | | - graph.erase_node(node) |
612 | | - modified = True |
613 | | - else: |
614 | | - self._vn_to_node[vn] = node |
615 | | - |
616 | | - if modified: |
617 | | - graph.eliminate_dead_code() |
618 | | - graph.lint() |
619 | | - |
620 | | - return PassResult(graph_module, modified) |
621 | | - |
622 | | - def _is_safe(self, target) -> bool: |
623 | | - """Cached version of _is_safe_target.""" |
624 | | - tid = id(target) |
625 | | - if tid not in self._safe_cache: |
626 | | - self._safe_cache[tid] = self._is_safe_target(target) |
627 | | - return self._safe_cache[tid] |
628 | | - |
629 | | - def _new_vn(self) -> int: |
630 | | - """Allocate a fresh unique value number.""" |
631 | | - vn = self._next_vn |
632 | | - self._next_vn += 1 |
633 | | - return vn |
634 | | - |
635 | | - def _value_number(self, node: Node) -> int: |
636 | | - """ |
637 | | - Assign an integer value number to a node (global value numbering). |
638 | | -
|
639 | | - Two nodes with the same value number are structurally equivalent |
640 | | - and can be deduplicated. All signature tuples are flat (contain |
641 | | - only ints and scalars), so hashing is O(n_args) not O(graph_depth). |
642 | | - """ |
643 | | - if node in self._vn_cache: |
644 | | - return self._vn_cache[node] |
645 | | - |
646 | | - if node.op != "call_function": |
647 | | - vn = self._new_vn() |
648 | | - elif node in self._output_nodes: |
649 | | - # Graph output node (includes buffer mutations like index_copy). |
650 | | - # Must keep unique — graph signature references this node by name. |
651 | | - vn = self._new_vn() |
652 | | - elif not self._is_safe(node.target): |
653 | | - vn = self._new_vn() |
654 | | - else: |
655 | | - try: |
656 | | - args_sig = tuple(self._make_hashable(a) for a in node.args) |
657 | | - kwargs_sig = tuple( |
658 | | - sorted((k, self._make_hashable(v)) for k, v in node.kwargs.items()) |
659 | | - ) |
660 | | - sig = (node.target, args_sig, kwargs_sig) |
661 | | - |
662 | | - if sig in self._sig_to_vn: |
663 | | - vn = self._sig_to_vn[sig] |
664 | | - else: |
665 | | - vn = self._new_vn() |
666 | | - self._sig_to_vn[sig] = vn |
667 | | - except TypeError: |
668 | | - vn = self._new_vn() |
669 | | - |
670 | | - self._vn_cache[node] = vn |
671 | | - return vn |
672 | | - |
673 | | - def _make_hashable(self, obj) -> Any: |
674 | | - """Convert args/kwargs to a hashable form. |
675 | | -
|
676 | | - For Node args, returns the integer value number — keeping |
677 | | - all signature tuples flat and O(1) to hash. |
678 | | - """ |
679 | | - if isinstance(obj, Node): |
680 | | - return self._value_number(obj) |
681 | | - elif isinstance(obj, (int, float, str, bool, type(None))): |
682 | | - return obj |
683 | | - elif isinstance(obj, (list, tuple)): |
684 | | - return tuple(self._make_hashable(x) for x in obj) |
685 | | - elif isinstance(obj, dict): |
686 | | - return tuple(sorted((k, self._make_hashable(v)) for k, v in obj.items())) |
687 | | - elif isinstance(obj, torch.dtype): |
688 | | - return obj |
689 | | - elif isinstance(obj, torch.device): |
690 | | - return str(obj) |
691 | | - elif isinstance(obj, torch.layout): |
692 | | - return str(obj) |
693 | | - elif isinstance(obj, torch.memory_format): |
694 | | - return str(obj) |
695 | | - else: |
696 | | - raise TypeError(f"Cannot make {type(obj)} hashable") |
0 commit comments