From ea917817a06c7932cc65f2775993e111b5261b9e Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 5 Aug 2025 15:44:08 -0400 Subject: [PATCH 1/3] Add support for adding methods actually defined on classes Adds support for methods like `__array_function__` which actually need to be added on the class as actual methods, not through overloading `__getattr__`. Custom methods can be registered by third party libraries. This PR also redoes the logic for upcasting when using binary operations. Instead of upcasting both values, it will only ever upcast one, choosing whichever one would be cheaper to upcast. This leads to more predictable behavior. --- docs/reference/python-integration.md | 5 + python/egglog/__init__.py | 3 +- python/egglog/conversion.py | 90 +++--- python/egglog/declarations.py | 74 ++++- python/egglog/egraph.py | 2 +- python/egglog/runtime.py | 299 ++++++++++-------- .../test_array_api/test_jit[lda][code].py | 4 +- .../test_array_api/test_jit[lda][expr].py | 17 +- .../test_jit[lda][initial_expr].py | 16 +- python/tests/test_high_level.py | 28 -- 10 files changed, 292 insertions(+), 246 deletions(-) diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index e1ca1a62..cd82d0a0 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since - `__iter_` - `__index__` +If you need additional methods to always be preserved and defined on the `Expr` class itself, instead of being +resolved through the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)` +with the name of the method. This is only needed for third party code that inspects the type object itself to see if a +method is defined instead of just attempting to call it. + ### Reflected methods Note that reflected methods (i.e. `__radd__`) are handled as a special case. 
If defined, they won't create their own egglog functions. diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index 0994b0d1..7e5344d3 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -5,7 +5,8 @@ from . import config, ipython_magic # noqa: F401 from .bindings import EggSmolError # noqa: F401 from .builtins import * # noqa: UP029 -from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401 +from .conversion import * from .egraph import * +from .runtime import define_expr_method as define_expr_method # noqa: PLC0414 del ipython_magic diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py index 072b1688..4b2fc0d1 100644 --- a/python/egglog/conversion.py +++ b/python/egglog/conversion.py @@ -1,10 +1,11 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from .declarations import * from .pretty import * @@ -13,14 +14,14 @@ from .type_constraint_solver import TypeConstraintError if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Generator from .egraph import BaseExpr from .type_constraint_solver import TypeConstraintSolver -__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"] +__all__ = ["ConvertError", "convert", "converter", "get_type_args"] # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target -CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} +CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {} # Global declerations to store all convertable types so we can query if they have 
certain methods or not _CONVERSION_DECLS = Declarations.create() # Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing @@ -28,7 +29,7 @@ _TO_PROCESS_DECLS: list[DeclerationsLike] = [] -def _retrieve_conversion_decls() -> Declarations: +def retrieve_conversion_decls() -> Declarations: _CONVERSION_DECLS.update(*_TO_PROCESS_DECLS) _TO_PROCESS_DECLS.clear() return _CONVERSION_DECLS @@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: to_type_name = process_tp(to_type) if not isinstance(to_type_name, JustTypeRef): raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}") - _register_converter(process_tp(from_type), to_type_name, fn, cost) + _register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost) -def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None: +def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None: """ Registers a converter from some type to an egglog type, if not already registered. @@ -97,15 +98,15 @@ class _ComposedConverter: We use the dataclass instead of the lambda to make it easier to debug. """ - a_b: Callable - b_c: Callable + a_b: Callable[[Any], RuntimeExpr] + b_c: Callable[[Any], RuntimeExpr] b_args: tuple[JustTypeRef, ...] 
- def __call__(self, x: object) -> object: + def __call__(self, x: Any) -> RuntimeExpr: # if we have A -> B and B[C] -> D then we should use (C,) as the type args # when converting from A -> B if self.b_args: - with with_type_args(self.b_args, _retrieve_conversion_decls): + with with_type_args(self.b_args, retrieve_conversion_decls): first_res = self.a_b(x) else: first_res = self.a_b(x) @@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type: return tp -def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: - """ - Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists. - """ - decls = _retrieve_conversion_decls() - a_tp = _get_tp(a) - b_tp = _get_tp(b) - # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly - if not ( - (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name)) - or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name)) - ): - raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}") - a_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) - } - b_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) - } - if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name): - a_converts_to[a_tp] = 0 - if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name): - b_converts_to[b_tp] = 0 - common = set(a_converts_to) & set(b_converts_to) - if not common: - raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") - return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) +# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: +# """ +# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type 
exists. +# """ +# decls = _retrieve_conversion_decls().copy() +# if isinstance(a, RuntimeExpr): +# decls |= a +# if isinstance(b, RuntimeExpr): +# decls |= b + +# a_tp = _get_tp(a) +# b_tp = _get_tp(b) +# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly +# if not ( +# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name)) +# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name)) +# ): +# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}") +# a_converts_to = { +# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) +# } +# b_converts_to = { +# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) +# } +# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name): +# a_converts_to[a_tp] = 0 +# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name): +# b_converts_to[b_tp] = 0 +# common = set(a_converts_to) & set(b_converts_to) +# if not common: +# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") +# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) def identity(x: object) -> object: @@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio def resolve_literal( tp: TypeOrVarRef, arg: object, - decls: Callable[[], Declarations] = _retrieve_conversion_decls, + decls: Callable[[], Declarations] = retrieve_conversion_decls, tcs: TypeConstraintSolver | None = None, cls_name: str | None = None, ) -> RuntimeExpr: @@ -208,12 +214,12 @@ def resolve_literal( If it cannot be resolved, we assume that the value passed in will resolve it. 
""" - arg_type = _get_tp(arg) + arg_type = resolve_type(arg) # If we have any type variables, dont bother trying to resolve the literal, just return the arg try: tp_just = tp.to_just() - except NotImplementedError: + except TypeVarError: # If this is a generic arg but passed in a non runtime expression, try to resolve the generic # args first based on the existing type constraint solver if tcs: @@ -258,7 +264,7 @@ def _debug_print_converers(): source_to_targets[source].append(target) -def _get_tp(x: object) -> JustTypeRef | type: +def resolve_type(x: object) -> JustTypeRef | type: if isinstance(x, RuntimeExpr): return x.__egg_typed_expr__.tp tp = type(x) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 6725ccd2..571d210c 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -73,6 +73,7 @@ "SpecialFunctions", "TypeOrVarRef", "TypeRefWithVars", + "TypeVarError", "TypedExprDecl", "UnboundVarDecl", "UnionDecl", @@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations: # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__` # instead raise explicitly except AttributeError as err: - msg = f"Cannot resolve declarations for {self}" + msg = f"Cannot resolve declarations {err}" raise RuntimeError(msg) from err @@ -225,14 +226,43 @@ def set_function_decl( case _: assert_never(ref) - def has_method(self, class_name: str, method_name: str) -> bool | None: + def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool: """ - Returns whether the given class has the given method, or None if we cant find the class. + Checks if the class has a binary method compatible with the given types. 
""" - if class_name in self._classes: - return method_name in self._classes[class_name].methods + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := self._classes[self_type.name].methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just( + vars, self_type + ) and other_arg_type.matches_just(vars, other_type): + return True + return False + + def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None: + """ + Checks if the class has a binary method with the given name and self type. Returns the other type if it exists. + """ + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := self._classes[self_type.name].methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type): + return other_arg_type.to_just(vars) return None + def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]: + """ + Returns the types which are compatible with the given binary method name and other type. 
+ """ + for class_decl in self._classes.values(): + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := class_decl.methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just( + vars, other_type + ): + yield self_arg_type.to_just(vars) + def get_class_decl(self, name: str) -> ClassDecl: return self._classes[name] @@ -300,6 +330,10 @@ def __str__(self) -> str: _RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {} +class TypeVarError(RuntimeError): + """Error when trying to resolve a type variable that doesn't exist.""" + + @dataclass(frozen=True) class ClassTypeVarRef: """ @@ -309,9 +343,10 @@ class ClassTypeVarRef: name: str module: str - def to_just(self) -> JustTypeRef: - msg = f"{self}: egglog does not support generic classes yet." - raise NotImplementedError(msg) + def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + if vars is None or self not in vars: + raise TypeVarError(f"Cannot convert {self} to just type") + return vars[self] def __str__(self) -> str: return str(self.to_type_var()) @@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef: def to_type_var(self) -> TypeVar: return _RESOLVED_TYPEVARS[self] + def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + """ + Checks if this type variable matches the given JustTypeRef, including type variables. + """ + if self in vars: + return vars[self] == other + vars[self] = other + return True + @dataclass(frozen=True) class TypeRefWithVars: name: str args: tuple[TypeOrVarRef, ...] 
= () - def to_just(self) -> JustTypeRef: - return JustTypeRef(self.name, tuple(a.to_just() for a in self.args)) + def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args)) def __str__(self) -> str: if self.args: return f"{self.name}[{', '.join(str(a) for a in self.args)}]" return self.name + def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + """ + Checks if this type reference matches the given JustTypeRef, including type variables. + """ + return ( + self.name == other.name + and len(self.args) == len(other.args) + and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True)) + ) + TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 90df1078..f973ac31 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -115,7 +115,7 @@ "__firstlineno__", "__static_attributes__", # Ignore all reflected binary method - *REFLECTED_BINARY_METHODS.keys(), + *(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS), } diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 487c3863..c3b78f82 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -11,12 +11,13 @@ from __future__ import annotations +import itertools import operator from collections.abc import Callable from dataclasses import dataclass, replace from inspect import Parameter, Signature from itertools import zip_longest -from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin from .declarations import * from .pretty import * @@ -27,15 +28,14 @@ if TYPE_CHECKING: from collections.abc import Iterable - from .egraph import Fact - __all__ = [ "LIT_CLASS_NAMES", - "REFLECTED_BINARY_METHODS", + "NUMERIC_BINARY_METHODS", "RuntimeClass", 
"RuntimeExpr", "RuntimeFunction", + "define_expr_method", "resolve_callable", "resolve_type_annotation", "resolve_type_annotation_mutate", @@ -46,24 +46,30 @@ UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"} LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"} -REFLECTED_BINARY_METHODS = { - "__radd__": "__add__", - "__rsub__": "__sub__", - "__rmul__": "__mul__", - "__rmatmul__": "__matmul__", - "__rtruediv__": "__truediv__", - "__rfloordiv__": "__floordiv__", - "__rmod__": "__mod__", - "__rpow__": "__pow__", - "__rlshift__": "__lshift__", - "__rrshift__": "__rshift__", - "__rand__": "__and__", - "__rxor__": "__xor__", - "__ror__": "__or__", +# All methods which should return NotImplemented if they fail to resolve and are reflected as well +# From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + +NUMERIC_BINARY_METHODS = { + "__add__", + "__sub__", + "__mul__", + "__matmul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", } -# Methods that need to return real Python values not expressions -PRESERVED_METHODS = [ + +# Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods. + +TYPE_DEFINED_METHODS = { "__bool__", "__len__", "__complex__", @@ -71,9 +77,18 @@ "__float__", "__iter__", "__index__", - "__float__", - "__int__", -] + "__call__", + "__getitem__", + "__setitem__", + "__delitem__", + "__pos__", + "__neg__", + "__invert__", + "__lt__", + "__le__", + "__gt__", + "__ge__", +} # Set this globally so we can get access to PyObject when we have a type annotation of just object. # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically. 
@@ -288,6 +303,14 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((id(self.__egg_decls_thunk__), self.__egg_tp__)) + def __eq__(self, other: object) -> bool: + """ + Support equality for runtime comparison of egglog classes. + """ + if not isinstance(other, RuntimeClass): + return NotImplemented + return self.__egg_tp__ == other.__egg_tp__ + # Support unioning like types def __or__(self, value: type) -> object: return Union[self, value] # noqa: UP007 @@ -357,7 +380,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: try: bound = py_signature.bind(*args, **kwargs) except TypeError as err: - raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err + raise TypeError(f"Wrong number of arguments for {self} with args {args} and kwargs {kwargs}") from err del kwargs bound.apply_defaults() assert not bound.kwargs @@ -437,32 +460,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: return Signature(parameters) -# All methods which should return NotImplemented if they fail to resolve -# From https://docs.python.org/3/reference/datamodel.html -PARTIAL_METHODS = { - "__lt__", - "__le__", - "__eq__", - "__ne__", - "__gt__", - "__ge__", - "__add__", - "__sub__", - "__mul__", - "__matmul__", - "__truediv__", - "__floordiv__", - "__mod__", - "__divmod__", - "__pow__", - "__lshift__", - "__rshift__", - "__and__", - "__xor__", - "__or__", -} - - @dataclass class RuntimeExpr(DelayedDeclerations): __egg_typed_expr_thunk__: Callable[[], TypedExprDecl] @@ -479,17 +476,14 @@ def __egg_typed_expr__(self) -> TypedExprDecl: return self.__egg_typed_expr_thunk__() def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None: - cls_name = self.__egg_class_name__ - class_decl = self.__egg_class_decl__ - - if name in (preserved_methods := class_decl.preserved_methods): - return preserved_methods[name].__get__(self) - - if name in class_decl.methods: - 
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self) - if name in class_decl.properties: - return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)() - raise AttributeError(f"{cls_name} has no method {name}") from None + if (method := _get_expr_method(self, name)) is not None: + return method + if name in self.__egg_class_decl__.properties: + fn = RuntimeFunction( + self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self + ) + return fn() + raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None def __repr__(self) -> str: """ @@ -520,13 +514,6 @@ def __egg_class_name__(self) -> str: def __egg_class_decl__(self) -> ClassDecl: return self.__egg_decls__.get_class_decl(self.__egg_class_name__) - # These both will be overriden below in the special methods section, but add these here for type hinting purposes - def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body] - ... - - def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body] - ... 
- # Implement these so that copy() works on this object # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion @@ -540,89 +527,127 @@ def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None: def __hash__(self) -> int: return hash(self.__egg_typed_expr__) + # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a + # preserved method or defined on the class + def __eq__(self, other: object) -> object: # type: ignore[override] + if (method := _get_expr_method(self, "__eq__")) is not None: + return method(other) -# Define each of the special methods, since we have already declared them for pretty printing -for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]: + # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other + # expr gets a chance to resolve __eq__ which could be a preserved method. + from .egraph import BaseExpr, eq - def _special_method( - self: RuntimeExpr, - *args: object, - __name: str = name, - **kwargs: object, - ) -> RuntimeExpr | Fact | None: - from .conversion import ConvertError + return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other)) - class_name = self.__egg_class_name__ - class_decl = self.__egg_class_decl__ - # First, try to resolve as preserved method - try: - method = class_decl.preserved_methods[__name] - except KeyError: - pass - else: - return method(self, *args, **kwargs) - # If this is a "partial" method meaning that it can return NotImplemented, - # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just - # using the arg type of the self arg. - # This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa. 
- if __name in PARTIAL_METHODS: - try: - return call_method_min_conversion(self, args[0], __name) - except ConvertError: - # Defer raising not imeplemented in case the dunder method is not symmetrical, then - # we use the standard process - pass - if __name in class_decl.methods: - fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self) - return fn(*args, **kwargs) # type: ignore[arg-type] - # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly. - if __name == "__eq__": - from .egraph import BaseExpr, eq - - return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])) - if __name == "__ne__": - from .egraph import BaseExpr, ne - - return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))) - - if __name in PARTIAL_METHODS: - return NotImplemented - raise TypeError(f"{class_name!r} object does not support {__name}") + def __ne__(self, other: object) -> object: # type: ignore[override] + if (method := _get_expr_method(self, "__ne__")) is not None: + return method(other) - setattr(RuntimeExpr, name, _special_method) + from .egraph import BaseExpr, ne -# For each of the reflected binary methods, translate to the corresponding non-reflected method -for reflected, non_reflected in REFLECTED_BINARY_METHODS.items(): + return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other)) - def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None: - # All binary methods are also "partial" meaning we should try to upcast first. - return call_method_min_conversion(other, self, __non_reflected) + def __call__( + self, *args: object, **kwargs: object + ) -> object: # define it here only for type checking, it will be overriden below + ... 
- setattr(RuntimeExpr, reflected, _reflected_method) +def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None: + if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods): + return preserved_methods[name].__get__(expr) -def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None: - from .conversion import min_convertable_tp, resolve_literal + if name in expr.__egg_class_decl__.methods: + return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr) + return None - # find a minimum type that both can be converted to - # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats. - min_tp = min_convertable_tp(slf, other, name).to_var() - slf = resolve_literal(min_tp, slf) - other = resolve_literal(min_tp, other) - method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf) - return method(other) +def define_expr_method(name: str) -> None: + """ + Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method. -for name in PRESERVED_METHODS: + Call this if you need a method to be defined on the type itself where overriding with `__getattr__` does not suffice, + like for NumPy's `__array_ufunc__`. 
+ """ - def _preserved_method(self: RuntimeExpr, __name: str = name): - try: - method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name] - except KeyError as e: - raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e - return method(self) + def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs): + fn = _get_expr_method(self, __name) + if fn is None: + raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}") + return fn(*args, **kwargs) + + setattr(RuntimeExpr, name, _defined_method) - setattr(RuntimeExpr, name, _preserved_method) + +for name in TYPE_DEFINED_METHODS: + define_expr_method(name) + + +for name, reversed in itertools.product(NUMERIC_BINARY_METHODS, (False, True)): + + def _numeric_binary_method(self: object, other: object, name: str = name, reversed: bool = reversed) -> object: + """ + Implements numeric binary operations. + + Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either + the LHS or the RHS as exactly the right type and then upcasting the other to that type. + """ + # 1. 
switch if reversed + if reversed: + self, other = other, self + # If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion + if not ( + isinstance(self, RuntimeExpr) + and isinstance(other, RuntimeExpr) + and ( + self.__egg_decls__.check_binary_method_with_types( + name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp + ) + ) + ): + from .conversion import CONVERSIONS, resolve_type, retrieve_conversion_decls + + # tuple of (cost, convert_self) + best_method: ( + tuple[ + int, + Callable[[Any], RuntimeExpr], + ] + | None + ) = None + # Start by checking if we have a LHS that matches exactly and a RHS which can be converted + if ( + isinstance(self, RuntimeExpr) + and ( + desired_other_type := self.__egg_decls__.check_binary_method_with_self_type( + name, self.__egg_typed_expr__.tp + ) + ) + and (converter := CONVERSIONS.get((resolve_type(other), desired_other_type))) + ): + best_method = (converter[0], lambda x: x) + + # Next see if it's possible to convert the LHS and keep the RHS as is + if isinstance(other, RuntimeExpr): + decls = retrieve_conversion_decls() + other_type = other.__egg_typed_expr__.tp + resolved_self_type = resolve_type(self) + for desired_self_type in decls.check_binary_method_with_other_type(name, other_type): + if converter := CONVERSIONS.get((resolved_self_type, desired_self_type)): + cost, convert_self = converter + if best_method is None or best_method[0] > cost: + best_method = (cost, convert_self) + + if not best_method: + raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found") + self = best_method[1](self) + + method_ref = MethodRef(self.__egg_class_name__, name) + fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self) + return fn(other) + + method_name = f"__r{name[2:]}" if reversed else name + setattr(RuntimeExpr, method_name, _numeric_binary_method) def resolve_callable(callable: object) -> 
tuple[CallableRef, Declarations]: diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py index 7bb9f1ea..f4942a8b 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py @@ -25,7 +25,7 @@ def __fn(X, y): _8[2, :,] = _14 _15 = _7 @ _8 _16 = X - _15 - _17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64))) + _17 = np.sqrt(np.asarray(np.array((float(1) / 147)), np.dtype(np.float64))) _18 = X[_0] - _8[0, :,] _19 = X[_2] - _8[1, :,] _20 = X[_4] - _8[2, :,] @@ -49,7 +49,7 @@ def __fn(X, y): _37 = _33[2][:_36, :,] / _29 _38 = _37.T / _33[1][:_36] _39 = np.array(150) * _7 - _40 = _39 * np.array(float(1 / 2)) + _40 = _39 * np.array((float(1) / 2)) _41 = np.sqrt(_40) _42 = _8 - _15 _43 = _41 * _42.T diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py index a1457ca4..18fe5202 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py @@ -30,31 +30,22 @@ _IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) _NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] _NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)])) +_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) _NDArray_8 = concat( TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0)) ) _NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / 
NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))) _NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)]))) _NDArray_11 = copy(_NDArray_10) -_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( - sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64))) - * (_NDArray_8 / _NDArray_11), - Boolean(False), -) +_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_12 = ( _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_11 ).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] _TupleNDArray_2 = svd( - ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T - ).T - @ _NDArray_12, + (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_12, Boolean(False), ) ( diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py index 0234b17d..1ee8a0ab 100644 --- 
a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py @@ -23,6 +23,7 @@ _IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) _IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2)))) _NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1) +_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) _NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) _IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) _NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1) @@ -41,14 +42,8 @@ OptionalInt.some(Int(0)), ) _NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1) -_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( - sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64))) - * (_NDArray_5 / _NDArray_6), - Boolean(False), -) +_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_7 = 
asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1))))) _NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7)))))) @@ -69,10 +64,7 @@ _NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) _NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1"))))) _TupleNDArray_2 = svd( - ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T - ).T + (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T @ ( ( _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index e8aa1c34..2165f832 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -384,34 +384,6 @@ def __radd__(self, other: Math) -> Math: ... ) -def test_upcast_args(): - # -0.1 + Int(x) -> -0.1 + Float(x) - EGraph() - - class Int(Expr): - def __init__(self, value: i64Like) -> None: ... - - def __add__(self, other: Int) -> Int: ... - - class Float(Expr): - def __init__(self, value: f64Like) -> None: ... - - def __add__(self, other: Float) -> Float: ... - - @classmethod - def from_int(cls, other: Int) -> Float: ... 
- - converter(i64, Int, Int) - converter(f64, Float, Float) - converter(Int, Float, Float.from_int) - - res: Expr = -0.1 + Int(10) # type: ignore[operator,assignment] - assert expr_parts(res) == expr_parts(Float(-0.1) + Float.from_int(Int(10))) - - res: Expr = Int(10) + -0.1 # type: ignore[operator,assignment] - assert expr_parts(res) == expr_parts(Float.from_int(Int(10)) + Float(-0.1)) - - def test_rewrite_upcasts(): class X(Expr): def __init__(self, value: i64Like) -> None: ... From 90da0cf3d3675239abe21b94412b1341b32d86f9 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 5 Aug 2025 15:49:45 -0400 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/egglog/declarations.py | 4 ++-- python/egglog/runtime.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 571d210c..24512704 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -96,7 +96,7 @@ def __egg_decls__(self) -> Declarations: # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__` # instead raise explicitly except AttributeError as err: - msg = f"Cannot resolve declarations {err}" + msg = f"Cannot resolve declarations for {self}: {err}" raise RuntimeError(msg) from err @@ -345,7 +345,7 @@ class ClassTypeVarRef: def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: if vars is None or self not in vars: - raise TypeVarError(f"Cannot convert {self} to just type") + raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings") return vars[self] def __str__(self) -> str: diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index c3b78f82..f43d51f1 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -380,7 +380,7 @@ def __call__(self, *args: object, 
_egg_partial_function: bool = False, **kwargs: try: bound = py_signature.bind(*args, **kwargs) except TypeError as err: - raise TypeError(f"Wrong number of arguments for {self} with args {args} and kwargs {kwargs}") from err + raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err del kwargs bound.apply_defaults() assert not bound.kwargs From 34fef64578326f28cd859296524e850fdf9d49a0 Mon Sep 17 00:00:00 2001 From: GitHub Action Date: Tue, 5 Aug 2025 19:52:38 +0000 Subject: [PATCH 3/3] Add changelog entry for PR #315 --- docs/changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.md b/docs/changelog.md index efbe4a11..7df8c05d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Support methods like `__array_function__` on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315) - Automatically Create Changelog Entry for PRs [#313](https://github.com/egraphs-good/egglog-python/pull/313) - Upgrade egglog which includes new backend. - Fixes implementation of the Python Object sort to work with objects with dupliating hashes but the same value.