Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/finchlite/finch_assembly/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Any

from ..algebra import return_type
from ..symbolic import Context, NamedTerm, Term, TermTree, ftype, literal_repr
from ..symbolic import Context, HashCons, NamedTerm, Term, TermTree, ftype, literal_repr
from ..util import qual_str
from .buffer import length_type


class AssemblyNode(Term):
class AssemblyNode(HashCons, Term):
"""
AssemblyNode

Expand Down
4 changes: 2 additions & 2 deletions src/finchlite/finch_einsum/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Any, Self, cast

from finchlite.algebra import ffunc
from finchlite.symbolic import Context, Term, TermTree
from finchlite.symbolic import Context, HashCons, Term, TermTree
from finchlite.util.print import qual_str


class EinsumNode(Term):
class EinsumNode(HashCons, Term):
@classmethod
def head(cls):
"""Returns the head of the node."""
Expand Down
4 changes: 2 additions & 2 deletions src/finchlite/finch_fused/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Self

from ..algebra import return_type
from ..symbolic import Context, NamedTerm, Term, TermTree, ftype, literal_repr
from ..symbolic import Context, HashCons, NamedTerm, Term, TermTree, ftype, literal_repr
from ..util import qual_str

"""
Expand All @@ -25,7 +25,7 @@


@dataclass(eq=True, frozen=True)
class FusedNode(Term, ABC):
class FusedNode(HashCons, Term, ABC):
@classmethod
def head(cls):
return cls
Expand Down
3 changes: 2 additions & 1 deletion src/finchlite/finch_logic/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Context,
FType,
FTyped,
HashCons,
NamedTerm,
Term,
TermTree,
Expand Down Expand Up @@ -109,7 +110,7 @@ def __eq__(self, other):


@dataclass(eq=True, frozen=True)
class LogicNode(Term, ABC):
class LogicNode(HashCons, Term, ABC):
"""
LogicNode

Expand Down
4 changes: 2 additions & 2 deletions src/finchlite/finch_notation/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from ..algebra import return_type
from ..finch_assembly import AssemblyNode
from ..symbolic import Context, FType, NamedTerm, Term, TermTree, ftype, literal_repr
from ..symbolic import Context, FType, HashCons, NamedTerm, Term, TermTree, ftype, literal_repr
from ..util import qual_str


@dataclass(eq=True, frozen=True)
class NotationNode(Term, ABC):
class NotationNode(HashCons, Term, ABC):
"""
NotationNode

Expand Down
2 changes: 2 additions & 0 deletions src/finchlite/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from .stage import Stage
from .term import (
HashCons,
Term,
TermTree,
literal_repr,
Expand All @@ -21,6 +22,7 @@
__all__ = [
"BasicBlock",
"Chain",
"HashCons",
"Context",
"ControlFlowGraph",
"DataFlowAnalysis",
Expand Down
36 changes: 36 additions & 0 deletions src/finchlite/symbolic/term.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect
import weakref
from abc import ABC, abstractmethod
from dataclasses import dataclass
from inspect import isbuiltin, isclass, isfunction
Expand Down Expand Up @@ -45,6 +47,40 @@ def recurse(node: Term) -> Term:
"""


hash_cons_table: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
_sig_cache: dict = {}


def _key_part(a: object) -> object:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method would be disaggregated across child classes of HashCons

if isinstance(a, Term):
return id(a)
if isinstance(a, tuple):
return tuple(_key_part(x) for x in a)
return (type(a), a)


class HashCons:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an abstractmethod called Hash_keys which defines which keys are used in the hashcons object

def __new__(cls, *args, **kwargs):
try:
if kwargs:
sig = _sig_cache.get(cls)
if sig is None:
_sig_cache[cls] = sig = inspect.signature(cls)
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
norm = tuple(bound.arguments.values())
else:
norm = args
key = (cls,) + tuple(_key_part(a) for a in norm)
obj = hash_cons_table.get(key)
if obj is None:
obj = object.__new__(cls)
hash_cons_table[key] = obj
return obj
except TypeError:
return object.__new__(cls)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define eq and hash as id



class Term:
@abstractmethod
def head(self) -> Any:
Expand Down
26 changes: 10 additions & 16 deletions tests/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ def test_preorder_dfs():
{"Plan": 1, "Produces": 1, "MapJoin": 1, "Table": 2, "Literal": 2, "Field": 5}
)

pos = {}
for i, obj in enumerate(preorder):
k = id(obj)
if k in pos:
continue
pos[k] = i
for node in preorder:
for i, node in enumerate(preorder):
for child in getattr(node, "children", ()):
assert pos[id(node)] < pos[id(child)]
child_pos = next(
(j for j, n in enumerate(preorder) if n is child and j > i), None
)
assert child_pos is not None


def test_postorder_dfs():
Expand Down Expand Up @@ -75,15 +72,12 @@ def test_postorder_dfs():
{"Plan": 1, "Produces": 1, "MapJoin": 1, "Table": 2, "Literal": 2, "Field": 5}
)

pos = {}
for i, obj in enumerate(postorder):
k = id(obj)
if k in pos:
continue
pos[k] = i
for node in postorder:
for i, node in enumerate(postorder):
for child in getattr(node, "children", ()):
assert pos[id(child)] < pos[id(node)]
child_pos = next(
(j for j, n in enumerate(postorder) if n is child and j < i), None
)
assert child_pos is not None


def test_intree():
Expand Down
Loading