Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since
- `__iter_`
- `__index__`

If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed
instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`,
with the name of a method. This is only needed for third party code that inspects the type object itself to see if a
method is defined instead of just attempting to call it.

### Reflected methods

Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.
Expand Down
3 changes: 2 additions & 1 deletion python/egglog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from . import config, ipython_magic # noqa: F401
from .bindings import EggSmolError # noqa: F401
from .builtins import * # noqa: UP029
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
from .conversion import *
from .egraph import *
from .runtime import define_expr_method as define_expr_method # noqa: PLC0414

del ipython_magic
90 changes: 48 additions & 42 deletions python/egglog/conversion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING, Any, TypeVar, cast

from .declarations import *
from .pretty import *
Expand All @@ -13,22 +14,22 @@
from .type_constraint_solver import TypeConstraintError

if TYPE_CHECKING:
from collections.abc import Callable, Generator
from collections.abc import Generator

from .egraph import BaseExpr
from .type_constraint_solver import TypeConstraintSolver

__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
# Global declerations to store all convertable types so we can query if they have certain methods or not
_CONVERSION_DECLS = Declarations.create()
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
# until we need them
_TO_PROCESS_DECLS: list[DeclerationsLike] = []


def _retrieve_conversion_decls() -> Declarations:
def retrieve_conversion_decls() -> Declarations:
_CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
_TO_PROCESS_DECLS.clear()
return _CONVERSION_DECLS
Expand All @@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
to_type_name = process_tp(to_type)
if not isinstance(to_type_name, JustTypeRef):
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
_register_converter(process_tp(from_type), to_type_name, fn, cost)
_register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)


def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
"""
Registers a converter from some type to an egglog type, if not already registered.

Expand Down Expand Up @@ -97,15 +98,15 @@ class _ComposedConverter:
We use the dataclass instead of the lambda to make it easier to debug.
"""

a_b: Callable
b_c: Callable
a_b: Callable[[Any], RuntimeExpr]
b_c: Callable[[Any], RuntimeExpr]
b_args: tuple[JustTypeRef, ...]

def __call__(self, x: object) -> object:
def __call__(self, x: Any) -> RuntimeExpr:
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
# when converting from A -> B
if self.b_args:
with with_type_args(self.b_args, _retrieve_conversion_decls):
with with_type_args(self.b_args, retrieve_conversion_decls):
first_res = self.a_b(x)
else:
first_res = self.a_b(x)
Expand Down Expand Up @@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
return tp


def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
"""
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
"""
decls = _retrieve_conversion_decls()
a_tp = _get_tp(a)
b_tp = _get_tp(b)
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
if not (
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
):
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
a_converts_to = {
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
}
b_converts_to = {
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
}
if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
a_converts_to[a_tp] = 0
if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
b_converts_to[b_tp] = 0
common = set(a_converts_to) & set(b_converts_to)
if not common:
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
# """
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
# """
# decls = _retrieve_conversion_decls().copy()
# if isinstance(a, RuntimeExpr):
# decls |= a
# if isinstance(b, RuntimeExpr):
# decls |= b

# a_tp = _get_tp(a)
# b_tp = _get_tp(b)
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
# if not (
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
# ):
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
# a_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
# }
# b_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
# }
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
# a_converts_to[a_tp] = 0
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
# b_converts_to[b_tp] = 0
# common = set(a_converts_to) & set(b_converts_to)
# if not common:
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])


def identity(x: object) -> object:
Expand Down Expand Up @@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
def resolve_literal(
tp: TypeOrVarRef,
arg: object,
decls: Callable[[], Declarations] = _retrieve_conversion_decls,
decls: Callable[[], Declarations] = retrieve_conversion_decls,
tcs: TypeConstraintSolver | None = None,
cls_name: str | None = None,
) -> RuntimeExpr:
Expand All @@ -208,12 +214,12 @@ def resolve_literal(

If it cannot be resolved, we assume that the value passed in will resolve it.
"""
arg_type = _get_tp(arg)
arg_type = resolve_type(arg)

# If we have any type variables, dont bother trying to resolve the literal, just return the arg
try:
tp_just = tp.to_just()
except NotImplementedError:
except TypeVarError:
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
# args first based on the existing type constraint solver
if tcs:
Expand Down Expand Up @@ -258,7 +264,7 @@ def _debug_print_converers():
source_to_targets[source].append(target)


def _get_tp(x: object) -> JustTypeRef | type:
def resolve_type(x: object) -> JustTypeRef | type:
if isinstance(x, RuntimeExpr):
return x.__egg_typed_expr__.tp
tp = type(x)
Expand Down
74 changes: 64 additions & 10 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"SpecialFunctions",
"TypeOrVarRef",
"TypeRefWithVars",
"TypeVarError",
"TypedExprDecl",
"UnboundVarDecl",
"UnionDecl",
Expand All @@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations:
# Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
# instead raise explicitly
except AttributeError as err:
msg = f"Cannot resolve declarations for {self}"
msg = f"Cannot resolve declarations {err}"
Comment thread
saulshanabrook marked this conversation as resolved.
Outdated
raise RuntimeError(msg) from err


Expand Down Expand Up @@ -225,14 +226,43 @@ def set_function_decl(
case _:
assert_never(ref)

def has_method(self, class_name: str, method_name: str) -> bool | None:
def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
"""
Returns whether the given class has the given method, or None if we cant find the class.
Checks if the class has a binary method compatible with the given types.
"""
if class_name in self._classes:
return method_name in self._classes[class_name].methods
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := self._classes[self_type.name].methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
vars, self_type
) and other_arg_type.matches_just(vars, other_type):
return True
return False

def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
"""
Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
"""
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := self._classes[self_type.name].methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
return other_arg_type.to_just(vars)
return None

def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
"""
Returns the types which are compatible with the given binary method name and other type.
"""
for class_decl in self._classes.values():
vars: dict[ClassTypeVarRef, JustTypeRef] = {}
if callable_decl := class_decl.methods.get(method_name):
match callable_decl.signature:
case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
vars, other_type
):
yield self_arg_type.to_just(vars)

def get_class_decl(self, name: str) -> ClassDecl:
return self._classes[name]

Expand Down Expand Up @@ -300,6 +330,10 @@ def __str__(self) -> str:
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}


class TypeVarError(RuntimeError):
"""Error when trying to resolve a type variable that doesn't exist."""


@dataclass(frozen=True)
class ClassTypeVarRef:
"""
Expand All @@ -309,9 +343,10 @@ class ClassTypeVarRef:
name: str
module: str

def to_just(self) -> JustTypeRef:
msg = f"{self}: egglog does not support generic classes yet."
raise NotImplementedError(msg)
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
if vars is None or self not in vars:
raise TypeVarError(f"Cannot convert {self} to just type")
Comment thread
saulshanabrook marked this conversation as resolved.
Outdated
return vars[self]

def __str__(self) -> str:
return str(self.to_type_var())
Expand All @@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
def to_type_var(self) -> TypeVar:
return _RESOLVED_TYPEVARS[self]

def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
"""
Checks if this type variable matches the given JustTypeRef, including type variables.
"""
if self in vars:
return vars[self] == other
vars[self] = other
return True


@dataclass(frozen=True)
class TypeRefWithVars:
name: str
args: tuple[TypeOrVarRef, ...] = ()

def to_just(self) -> JustTypeRef:
return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))

def __str__(self) -> str:
if self.args:
return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
return self.name

def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
"""
Checks if this type reference matches the given JustTypeRef, including type variables.
"""
return (
self.name == other.name
and len(self.args) == len(other.args)
and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
)


TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars

Expand Down
2 changes: 1 addition & 1 deletion python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
"__firstlineno__",
"__static_attributes__",
# Ignore all reflected binary method
*REFLECTED_BINARY_METHODS.keys(),
*(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS),
}


Expand Down
Loading
Loading