diff --git a/src/finchlite/finch_assembly/nodes.py b/src/finchlite/finch_assembly/nodes.py index 64a750bb..b3d18ad6 100644 --- a/src/finchlite/finch_assembly/nodes.py +++ b/src/finchlite/finch_assembly/nodes.py @@ -3,12 +3,12 @@ from typing import Any from ..algebra import return_type -from ..symbolic import Context, NamedTerm, Term, TermTree, ftype, literal_repr +from ..symbolic import Context, HashCons, NamedTerm, Term, TermTree, ftype, literal_repr from ..util import qual_str from .buffer import length_type -class AssemblyNode(Term): +class AssemblyNode(HashCons, Term): """ AssemblyNode diff --git a/src/finchlite/finch_einsum/nodes.py b/src/finchlite/finch_einsum/nodes.py index dbf44dbb..fa4dc421 100644 --- a/src/finchlite/finch_einsum/nodes.py +++ b/src/finchlite/finch_einsum/nodes.py @@ -3,11 +3,11 @@ from typing import Any, Self, cast from finchlite.algebra import ffunc -from finchlite.symbolic import Context, Term, TermTree +from finchlite.symbolic import Context, HashCons, Term, TermTree from finchlite.util.print import qual_str -class EinsumNode(Term): +class EinsumNode(HashCons, Term): @classmethod def head(cls): """Returns the head of the node.""" diff --git a/src/finchlite/finch_fused/nodes.py b/src/finchlite/finch_fused/nodes.py index 56722442..6f52d61e 100644 --- a/src/finchlite/finch_fused/nodes.py +++ b/src/finchlite/finch_fused/nodes.py @@ -7,7 +7,7 @@ from typing import Any, Self from ..algebra import return_type -from ..symbolic import Context, NamedTerm, Term, TermTree, ftype, literal_repr +from ..symbolic import Context, HashCons, NamedTerm, Term, TermTree, ftype, literal_repr from ..util import qual_str """ @@ -25,7 +25,7 @@ @dataclass(eq=True, frozen=True) -class FusedNode(Term, ABC): +class FusedNode(HashCons, Term, ABC): @classmethod def head(cls): return cls diff --git a/src/finchlite/finch_logic/nodes.py b/src/finchlite/finch_logic/nodes.py index 369989fc..1f44e850 100644 --- a/src/finchlite/finch_logic/nodes.py +++ 
# Global intern table: canonical constructor key -> live node. Values are held
# weakly, so an entry disappears as soon as nothing else references the node.
hash_cons_table: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

# Per-class memo of ``(signature, arity)`` as computed by ``_ctor_info``.
_sig_cache: dict = {}

_VARIADIC_KINDS = (
    inspect.Parameter.VAR_POSITIONAL,
    inspect.Parameter.VAR_KEYWORD,
)


def _key_part(a: object) -> object:
    """
    Convert one constructor argument into a component of a hash-cons key.

    - ``Term`` arguments are keyed by identity (``id``).
    - Tuples are keyed element-wise, recursively.
    - Anything else is keyed by ``(type(a), a)``, so ``1``, ``1.0`` and
      ``True`` stay distinct even though they compare equal.

    The result is only usable as a dictionary key if every non-``Term`` leaf
    is hashable; the caller handles the resulting ``TypeError``.
    """
    if isinstance(a, Term):
        return id(a)
    if isinstance(a, tuple):
        return tuple(_key_part(x) for x in a)
    # NOTE(review): leaves are compared with ``==``, so equal-but-distinguishable
    # values of one type (e.g. ``0.0`` and ``-0.0``) share a key.
    return (type(a), a)


def _ctor_info(cls: type) -> tuple:
    """
    Return ``(signature, arity)`` for constructing ``cls``, memoized per class.

    ``signature`` is the signature of ``cls.__init__`` without its leading
    ``self`` parameter, or ``None`` when it cannot be introspected. ``arity``
    is the total number of parameters, or ``None`` when the signature is
    unavailable or contains ``*args`` / ``**kwargs``.
    """
    info = _sig_cache.get(cls)
    if info is not None:
        return info
    info = (None, None)
    try:
        # Inspect __init__ directly: ``inspect.signature(cls)`` may resolve to
        # the inherited ``__new__(*args, **kwargs)`` on some Python versions.
        params = list(inspect.signature(cls.__init__).parameters.values())
    except (TypeError, ValueError):
        params = []
    if params and params[0].kind not in _VARIADIC_KINDS:
        params = params[1:]  # drop ``self``
        variadic = any(p.kind in _VARIADIC_KINDS for p in params)
        info = (inspect.Signature(params), None if variadic else len(params))
    _sig_cache[cls] = info
    return info


def _canonical_args(cls: type, args: tuple, kwargs: dict) -> tuple | None:
    """
    Return the arguments of ``cls(*args, **kwargs)`` in declaration order with
    defaults filled in, or ``None`` when they cannot be normalized.
    """
    sig, arity = _ctor_info(cls)
    # Fast path: a purely positional call that already supplies every
    # parameter (or whose signature cannot be reasoned about) is used as-is.
    if not kwargs and (arity is None or len(args) == arity):
        return args
    if sig is None:
        return None
    try:
        bound = sig.bind(*args, **kwargs)
    except TypeError:
        # Invalid call; let ``__init__`` raise the user-facing error.
        return None
    bound.apply_defaults()
    return tuple(bound.arguments.values())


class HashCons:
    """
    Mixin that interns instances by their constructor arguments.

    Two constructions of the same class return the *same* object when their
    arguments are equivalent under ``_key_part``. Positional, keyword and
    defaulted spellings of one call are normalized to a single key, so
    ``Node(x)``, ``Node(a=x)`` and ``Node(x, <default>)`` are shared.

    Calls that cannot be keyed (unhashable arguments, arguments that do not
    match the signature, instances that cannot be weakly referenced) fall
    back to a fresh, un-interned instance.

    Python still runs ``__init__`` on whatever ``__new__`` returns, so a
    cached instance is re-initialized with the new (equivalent) arguments.
    """

    def __new__(cls, *args, **kwargs):
        norm = _canonical_args(cls, args, kwargs)
        if norm is None:
            return object.__new__(cls)
        key = (cls, *(_key_part(a) for a in norm))
        try:
            obj = hash_cons_table.get(key)
            if obj is None:
                obj = object.__new__(cls)
                hash_cons_table[key] = obj
        except TypeError:
            # Unhashable key component, or ``cls`` does not support weakrefs.
            obj = object.__new__(cls)
        return obj