Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315)
- Automatically Create Changelog Entry for PRs [#313](https://github.com/egraphs-good/egglog-python/pull/313)
- Upgrade egglog which includes new backend.
- Fixes implementation of the Python Object sort to work with objects with dupliating hashes but the same value.
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since
- `__iter_`
- `__index__`

If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed
instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`,
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
method is defined instead of just attempting to call it.

### Reflected methods

Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.
Expand Down
3 changes: 2 additions & 1 deletion python/egglog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from . import config, ipython_magic # noqa: F401
from .bindings import EggSmolError # noqa: F401
from .builtins import * # noqa: UP029
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
from .conversion import *
from .egraph import *
from .runtime import define_expr_method as define_expr_method # noqa: PLC0414

del ipython_magic
90 changes: 48 additions & 42 deletions python/egglog/conversion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast

from .declarations import *
from .pretty import *
Expand All @@ -13,22 +14,22 @@
from .type_constraint_solver import TypeConstraintError

if TYPE_CHECKING:
from collections.abc import Callable, Generator
from collections.abc import Generator

from .egraph import BaseExpr
from .type_constraint_solver import TypeConstraintSolver

__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
# Global declerations to store all convertable types so we can query if they have certain methods or not
_CONVERSION_DECLS = Declarations.create()
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
# until we need them
_TO_PROCESS_DECLS: list[DeclerationsLike] = []


def _retrieve_conversion_decls() -> Declarations:
def retrieve_conversion_decls() -> Declarations:
_CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
_TO_PROCESS_DECLS.clear()
return _CONVERSION_DECLS
Expand All @@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
to_type_name = process_tp(to_type)
if not isinstance(to_type_name, JustTypeRef):
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
_register_converter(process_tp(from_type), to_type_name, fn, cost)
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)


def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
"""
Registers a converter from some type to an egglog type, if not already registered.

Expand Down Expand Up @@ -97,15 +98,15 @@ class _ComposedConverter:
We use the dataclass instead of the lambda to make it easier to debug.
"""

a_b: Callable
b_c: Callable
a_b: Callable[[Any], RuntimeExpr]
b_c: Callable[[Any], RuntimeExpr]
b_args: tuple[JustTypeRef, ...]

def __call__(self, x: object) -> object:
def __call__(self, x: Any) -> RuntimeExpr:
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
# when converting from A -> B
if self.b_args:
with with_type_args(self.b_args, _retrieve_conversion_decls):
with with_type_args(self.b_args, retrieve_conversion_decls):
first_res = self.a_b(x)
else:
first_res = self.a_b(x)
Expand Down Expand Up @@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
return tp


def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
"""
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
"""
decls = _retrieve_conversion_decls()
a_tp = _get_tp(a)
b_tp = _get_tp(b)
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
if not (
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
):
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
a_converts_to = {
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
}
b_converts_to = {
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
}
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
a_converts_to[a_tp] = 0
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
b_converts_to[b_tp] = 0
common = set(a_converts_to) & set(b_converts_to)
if not common:
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
# """
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
# """
# decls = _retrieve_conversion_decls().copy()
# if isinstance(a, RuntimeExpr):
# decls |= a
# if isinstance(b, RuntimeExpr):
# decls |= b

# a_tp = _get_tp(a)
# b_tp = _get_tp(b)
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
# if not (
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
# ):
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
# a_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
# }
# b_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
# }
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
# a_converts_to[a_tp] = 0
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
# b_converts_to[b_tp] = 0
# common = set(a_converts_to) & set(b_converts_to)
# if not common:
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])


def identity(x: object) -> object:
Expand Down Expand Up @@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
def resolve_literal(
tp: TypeOrVarRef,
arg: object,
decls: Callable[[], Declarations] = _retrieve_conversion_decls,
decls: Callable[[], Declarations] = retrieve_conversion_decls,
tcs: TypeConstraintSolver | None = None,
cls_name: str | None = None,
) -> RuntimeExpr:
Expand All @@ -208,12 +214,12 @@ def resolve_literal(

If it cannot be resolved, we assume that the value passed in will resolve it.
"""
arg_type = _get_tp(arg)
arg_type = resolve_type(arg)

# If we have any type variables, dont bother trying to resolve the literal, just return the arg
try:
tp_just = tp.to_just()
except NotImplementedError:
except TypeVarError:
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
# args first based on the existing type constraint solver
if tcs:
Expand Down Expand Up @@ -258,7 +264,7 @@ def _debug_print_converers():
source_to_targets[source].append(target)


def _get_tp(x: object) -> JustTypeRef | type:
def resolve_type(x: object) -> JustTypeRef | type:
if isinstance(x, RuntimeExpr):
return x.__egg_typed_expr__.tp
tp = type(x)
Expand Down
74 changes: 64 additions & 10 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"SpecialFunctions",
"TypeOrVarRef",
"TypeRefWithVars",
"TypeVarError",
"TypedExprDecl",
"UnboundVarDecl",
"UnionDecl",
Expand All @@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations:
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
# instead raise explicitly
except AttributeError as err:
msg = f"Cannot resolve declarations for {self}"
msg = f"Cannot resolve declarations for {self}: {err}"
raise RuntimeError(msg) from err


Expand Down Expand Up @@ -225,14 +226,43 @@ def set_function_decl(
case _:
assert_never(ref)

def has_method(self, class_name: str, method_name: str) -> bool | None:
def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
"""
Returns whether the given class has the given method, or None if we cant find the class.
Checks if the class has a binary method compatible with the given types.
"""
if class_name in self._classes:
return method_name in self._classes[class_name].methods
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := self._classes[self_type.name].methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
vars, self_type
) and other_arg_type.matches_just(vars, other_type):
return True
return False

def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
"""
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
"""
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := self._classes[self_type.name].methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
return other_arg_type.to_just(vars)
return None

def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
"""
Returns the types which are compatible with the given binary method name and other type.
"""
for class_decl in self._classes.values():
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := class_decl.methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
vars, other_type
):
yield self_arg_type.to_just(vars)

def get_class_decl(self, name: str) -> ClassDecl:
return self._classes[name]

Expand Down Expand Up @@ -300,6 +330,10 @@ def __str__(self) -> str:
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}


class TypeVarError(RuntimeError):
"""Error when trying to resolve a type variable that doesn't exist."""


@dataclass(frozen=True)
class ClassTypeVarRef:
"""
Expand All @@ -309,9 +343,10 @@ class ClassTypeVarRef:
name: str
module: str

def to_just(self) -> JustTypeRef:
msg = f"{self}: egglog does not support generic classes yet."
raise NotImplementedError(msg)
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
if vars is None or self not in vars:
raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings")
return vars[self]

def __str__(self) -> str:
return str(self.to_type_var())
Expand All @@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
def to_type_var(self) -> TypeVar:
return _RESOLVED_TYPEVARS[self]

def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
"""
Checks if this type variable matches the given JustTypeRef, including type variables.
"""
if self in vars:
return vars[self] == other
vars[self] = other
return True


@dataclass(frozen=True)
class TypeRefWithVars:
name: str
args: tuple[TypeOrVarRef, ...] = ()

def to_just(self) -> JustTypeRef:
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))

def __str__(self) -> str:
if self.args:
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
return self.name

def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
"""
Checks if this type reference matches the given JustTypeRef, including type variables.
"""
return (
self.name == other.name
and len(self.args) == len(other.args)
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
)


TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars

Expand Down
2 changes: 1 addition & 1 deletion python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
"__firstlineno__",
"__static_attributes__",
# Ignore all reflected binary method
*REFLECTED_BINARY_METHODS.keys(),
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
}


Expand Down
Loading