diff --git a/docs/changelog.md b/docs/changelog.md index efbe4a11..7df8c05d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Support methods like on expressions [#315](https://github.com/egraphs-good/egglog-python/pull/315) - Automatically Create Changelog Entry for PRs [#313](https://github.com/egraphs-good/egglog-python/pull/313) - Upgrade egglog which includes new backend. - Fixes implementation of the Python Object sort to work with objects with dupliating hashes but the same value. diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index e1ca1a62..cd82d0a0 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -303,6 +303,11 @@ Note that the following list of methods are only supported as "preserved" since - `__iter_` - `__index__` +If you want to register additional methods as always preserved and defined on the `Expr` class itself, if needed +instead of the normal mechanism which relies on `__getattr__`, you can call `egglog.define_expr_method(name: str)`, +with the name of a method. This is only needed for third party code that inspects the type object itself to see if a +method is defined instead of just attempting to call it. + ### Reflected methods Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions. diff --git a/python/egglog/__init__.py b/python/egglog/__init__.py index 0994b0d1..7e5344d3 100644 --- a/python/egglog/__init__.py +++ b/python/egglog/__init__.py @@ -5,7 +5,8 @@ from . 
import config, ipython_magic # noqa: F401 from .bindings import EggSmolError # noqa: F401 from .builtins import * # noqa: UP029 -from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401 +from .conversion import * from .egraph import * +from .runtime import define_expr_method as define_expr_method # noqa: PLC0414 del ipython_magic diff --git a/python/egglog/conversion.py b/python/egglog/conversion.py index 072b1688..4b2fc0d1 100644 --- a/python/egglog/conversion.py +++ b/python/egglog/conversion.py @@ -1,10 +1,11 @@ from __future__ import annotations from collections import defaultdict +from collections.abc import Callable from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast from .declarations import * from .pretty import * @@ -13,14 +14,14 @@ from .type_constraint_solver import TypeConstraintError if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Generator from .egraph import BaseExpr from .type_constraint_solver import TypeConstraintSolver -__all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"] +__all__ = ["ConvertError", "convert", "converter", "get_type_args"] # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target -CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {} +CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {} # Global declerations to store all convertable types so we can query if they have certain methods or not _CONVERSION_DECLS = Declarations.create() # Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing @@ -28,7 +29,7 @@ _TO_PROCESS_DECLS: list[DeclerationsLike] = [] 
-def _retrieve_conversion_decls() -> Declarations: +def retrieve_conversion_decls() -> Declarations: _CONVERSION_DECLS.update(*_TO_PROCESS_DECLS) _TO_PROCESS_DECLS.clear() return _CONVERSION_DECLS @@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: to_type_name = process_tp(to_type) if not isinstance(to_type_name, JustTypeRef): raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}") - _register_converter(process_tp(from_type), to_type_name, fn, cost) + _register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost) -def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None: +def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None: """ Registers a converter from some type to an egglog type, if not already registered. @@ -97,15 +98,15 @@ class _ComposedConverter: We use the dataclass instead of the lambda to make it easier to debug. """ - a_b: Callable - b_c: Callable + a_b: Callable[[Any], RuntimeExpr] + b_c: Callable[[Any], RuntimeExpr] b_args: tuple[JustTypeRef, ...] - def __call__(self, x: object) -> object: + def __call__(self, x: Any) -> RuntimeExpr: # if we have A -> B and B[C] -> D then we should use (C,) as the type args # when converting from A -> B if self.b_args: - with with_type_args(self.b_args, _retrieve_conversion_decls): + with with_type_args(self.b_args, retrieve_conversion_decls): first_res = self.a_b(x) else: first_res = self.a_b(x) @@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type: return tp -def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: - """ - Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists. 
- """ - decls = _retrieve_conversion_decls() - a_tp = _get_tp(a) - b_tp = _get_tp(b) - # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly - if not ( - (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name)) - or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name)) - ): - raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}") - a_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) - } - b_converts_to = { - to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) - } - if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name): - a_converts_to[a_tp] = 0 - if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name): - b_converts_to[b_tp] = 0 - common = set(a_converts_to) & set(b_converts_to) - if not common: - raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") - return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) +# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef: +# """ +# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists. 
+# """ +# decls = _retrieve_conversion_decls().copy() +# if isinstance(a, RuntimeExpr): +# decls |= a +# if isinstance(b, RuntimeExpr): +# decls |= b + +# a_tp = _get_tp(a) +# b_tp = _get_tp(b) +# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly +# if not ( +# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name)) +# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name)) +# ): +# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}") +# a_converts_to = { +# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name) +# } +# b_converts_to = { +# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name) +# } +# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name): +# a_converts_to[a_tp] = 0 +# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name): +# b_converts_to[b_tp] = 0 +# common = set(a_converts_to) & set(b_converts_to) +# if not common: +# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type") +# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp]) def identity(x: object) -> object: @@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio def resolve_literal( tp: TypeOrVarRef, arg: object, - decls: Callable[[], Declarations] = _retrieve_conversion_decls, + decls: Callable[[], Declarations] = retrieve_conversion_decls, tcs: TypeConstraintSolver | None = None, cls_name: str | None = None, ) -> RuntimeExpr: @@ -208,12 +214,12 @@ def resolve_literal( If it cannot be resolved, we assume that the value passed in will resolve it. 
""" - arg_type = _get_tp(arg) + arg_type = resolve_type(arg) # If we have any type variables, dont bother trying to resolve the literal, just return the arg try: tp_just = tp.to_just() - except NotImplementedError: + except TypeVarError: # If this is a generic arg but passed in a non runtime expression, try to resolve the generic # args first based on the existing type constraint solver if tcs: @@ -258,7 +264,7 @@ def _debug_print_converers(): source_to_targets[source].append(target) -def _get_tp(x: object) -> JustTypeRef | type: +def resolve_type(x: object) -> JustTypeRef | type: if isinstance(x, RuntimeExpr): return x.__egg_typed_expr__.tp tp = type(x) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 6725ccd2..24512704 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -73,6 +73,7 @@ "SpecialFunctions", "TypeOrVarRef", "TypeRefWithVars", + "TypeVarError", "TypedExprDecl", "UnboundVarDecl", "UnionDecl", @@ -95,7 +96,7 @@ def __egg_decls__(self) -> Declarations: # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__` # instead raise explicitly except AttributeError as err: - msg = f"Cannot resolve declarations for {self}" + msg = f"Cannot resolve declarations for {self}: {err}" raise RuntimeError(msg) from err @@ -225,14 +226,43 @@ def set_function_decl( case _: assert_never(ref) - def has_method(self, class_name: str, method_name: str) -> bool | None: + def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool: """ - Returns whether the given class has the given method, or None if we cant find the class. + Checks if the class has a binary method compatible with the given types. 
""" - if class_name in self._classes: - return method_name in self._classes[class_name].methods + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := self._classes[self_type.name].methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just( + vars, self_type + ) and other_arg_type.matches_just(vars, other_type): + return True + return False + + def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None: + """ + Checks if the class has a binary method with the given name and self type. Returns the other type if it exists. + """ + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := self._classes[self_type.name].methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type): + return other_arg_type.to_just(vars) return None + def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]: + """ + Returns the types which are compatible with the given binary method name and other type. 
+ """ + for class_decl in self._classes.values(): + vars: dict[ClassTypeVarRef, JustTypeRef] = {} + if callable_decl := class_decl.methods.get(method_name): + match callable_decl.signature: + case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just( + vars, other_type + ): + yield self_arg_type.to_just(vars) + def get_class_decl(self, name: str) -> ClassDecl: return self._classes[name] @@ -300,6 +330,10 @@ def __str__(self) -> str: _RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {} +class TypeVarError(RuntimeError): + """Error when trying to resolve a type variable that doesn't exist.""" + + @dataclass(frozen=True) class ClassTypeVarRef: """ @@ -309,9 +343,10 @@ class ClassTypeVarRef: name: str module: str - def to_just(self) -> JustTypeRef: - msg = f"{self}: egglog does not support generic classes yet." - raise NotImplementedError(msg) + def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + if vars is None or self not in vars: + raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings") + return vars[self] def __str__(self) -> str: return str(self.to_type_var()) @@ -325,20 +360,39 @@ def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef: def to_type_var(self) -> TypeVar: return _RESOLVED_TYPEVARS[self] + def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + """ + Checks if this type variable matches the given JustTypeRef, including type variables. + """ + if self in vars: + return vars[self] == other + vars[self] = other + return True + @dataclass(frozen=True) class TypeRefWithVars: name: str args: tuple[TypeOrVarRef, ...] 
= () - def to_just(self) -> JustTypeRef: - return JustTypeRef(self.name, tuple(a.to_just() for a in self.args)) + def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef: + return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args)) def __str__(self) -> str: if self.args: return f"{self.name}[{', '.join(str(a) for a in self.args)}]" return self.name + def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool: + """ + Checks if this type reference matches the given JustTypeRef, including type variables. + """ + return ( + self.name == other.name + and len(self.args) == len(other.args) + and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True)) + ) + TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 90df1078..f973ac31 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -115,7 +115,7 @@ "__firstlineno__", "__static_attributes__", # Ignore all reflected binary method - *REFLECTED_BINARY_METHODS.keys(), + *(f"__r{m[2:]}" for m in NUMERIC_BINARY_METHODS), } diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 487c3863..f43d51f1 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -11,12 +11,13 @@ from __future__ import annotations +import itertools import operator from collections.abc import Callable from dataclasses import dataclass, replace from inspect import Parameter, Signature from itertools import zip_longest -from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin from .declarations import * from .pretty import * @@ -27,15 +28,14 @@ if TYPE_CHECKING: from collections.abc import Iterable - from .egraph import Fact - __all__ = [ "LIT_CLASS_NAMES", - "REFLECTED_BINARY_METHODS", + "NUMERIC_BINARY_METHODS", "RuntimeClass", 
"RuntimeExpr", "RuntimeFunction", + "define_expr_method", "resolve_callable", "resolve_type_annotation", "resolve_type_annotation_mutate", @@ -46,24 +46,30 @@ UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"} LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"} -REFLECTED_BINARY_METHODS = { - "__radd__": "__add__", - "__rsub__": "__sub__", - "__rmul__": "__mul__", - "__rmatmul__": "__matmul__", - "__rtruediv__": "__truediv__", - "__rfloordiv__": "__floordiv__", - "__rmod__": "__mod__", - "__rpow__": "__pow__", - "__rlshift__": "__lshift__", - "__rrshift__": "__rshift__", - "__rand__": "__and__", - "__rxor__": "__xor__", - "__ror__": "__or__", +# All methods which should return NotImplemented if they fail to resolve and are reflected as well +# From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types + +NUMERIC_BINARY_METHODS = { + "__add__", + "__sub__", + "__mul__", + "__matmul__", + "__truediv__", + "__floordiv__", + "__mod__", + "__divmod__", + "__pow__", + "__lshift__", + "__rshift__", + "__and__", + "__xor__", + "__or__", } -# Methods that need to return real Python values not expressions -PRESERVED_METHODS = [ + +# Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods. + +TYPE_DEFINED_METHODS = { "__bool__", "__len__", "__complex__", @@ -71,9 +77,18 @@ "__float__", "__iter__", "__index__", - "__float__", - "__int__", -] + "__call__", + "__getitem__", + "__setitem__", + "__delitem__", + "__pos__", + "__neg__", + "__invert__", + "__lt__", + "__le__", + "__gt__", + "__ge__", +} # Set this globally so we can get access to PyObject when we have a type annotation of just object. # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically. 
@@ -288,6 +303,14 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash((id(self.__egg_decls_thunk__), self.__egg_tp__)) + def __eq__(self, other: object) -> bool: + """ + Support equality for runtime comparison of egglog classes. + """ + if not isinstance(other, RuntimeClass): + return NotImplemented + return self.__egg_tp__ == other.__egg_tp__ + # Support unioning like types def __or__(self, value: type) -> object: return Union[self, value] # noqa: UP007 @@ -357,7 +380,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: try: bound = py_signature.bind(*args, **kwargs) except TypeError as err: - raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err + raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err del kwargs bound.apply_defaults() assert not bound.kwargs @@ -437,32 +460,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: return Signature(parameters) -# All methods which should return NotImplemented if they fail to resolve -# From https://docs.python.org/3/reference/datamodel.html -PARTIAL_METHODS = { - "__lt__", - "__le__", - "__eq__", - "__ne__", - "__gt__", - "__ge__", - "__add__", - "__sub__", - "__mul__", - "__matmul__", - "__truediv__", - "__floordiv__", - "__mod__", - "__divmod__", - "__pow__", - "__lshift__", - "__rshift__", - "__and__", - "__xor__", - "__or__", -} - - @dataclass class RuntimeExpr(DelayedDeclerations): __egg_typed_expr_thunk__: Callable[[], TypedExprDecl] @@ -479,17 +476,14 @@ def __egg_typed_expr__(self) -> TypedExprDecl: return self.__egg_typed_expr_thunk__() def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None: - cls_name = self.__egg_class_name__ - class_decl = self.__egg_class_decl__ - - if name in (preserved_methods := class_decl.preserved_methods): - return preserved_methods[name].__get__(self) - - if name in class_decl.methods: - 
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self) - if name in class_decl.properties: - return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)() - raise AttributeError(f"{cls_name} has no method {name}") from None + if (method := _get_expr_method(self, name)) is not None: + return method + if name in self.__egg_class_decl__.properties: + fn = RuntimeFunction( + self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self + ) + return fn() + raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None def __repr__(self) -> str: """ @@ -520,13 +514,6 @@ def __egg_class_name__(self) -> str: def __egg_class_decl__(self) -> ClassDecl: return self.__egg_decls__.get_class_decl(self.__egg_class_name__) - # These both will be overriden below in the special methods section, but add these here for type hinting purposes - def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body] - ... - - def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body] - ... 
- # Implement these so that copy() works on this object # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion @@ -540,89 +527,127 @@ def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None: def __hash__(self) -> int: return hash(self.__egg_typed_expr__) + # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a + # preserved method or defined on the class + def __eq__(self, other: object) -> object: # type: ignore[override] + if (method := _get_expr_method(self, "__eq__")) is not None: + return method(other) -# Define each of the special methods, since we have already declared them for pretty printing -for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]: + # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other + # expr gets a chance to resolve __eq__ which could be a preserved method. + from .egraph import BaseExpr, eq - def _special_method( - self: RuntimeExpr, - *args: object, - __name: str = name, - **kwargs: object, - ) -> RuntimeExpr | Fact | None: - from .conversion import ConvertError + return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other)) - class_name = self.__egg_class_name__ - class_decl = self.__egg_class_decl__ - # First, try to resolve as preserved method - try: - method = class_decl.preserved_methods[__name] - except KeyError: - pass - else: - return method(self, *args, **kwargs) - # If this is a "partial" method meaning that it can return NotImplemented, - # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just - # using the arg type of the self arg. - # This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa. 
- if __name in PARTIAL_METHODS: - try: - return call_method_min_conversion(self, args[0], __name) - except ConvertError: - # Defer raising not imeplemented in case the dunder method is not symmetrical, then - # we use the standard process - pass - if __name in class_decl.methods: - fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self) - return fn(*args, **kwargs) # type: ignore[arg-type] - # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly. - if __name == "__eq__": - from .egraph import BaseExpr, eq - - return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])) - if __name == "__ne__": - from .egraph import BaseExpr, ne - - return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))) - - if __name in PARTIAL_METHODS: - return NotImplemented - raise TypeError(f"{class_name!r} object does not support {__name}") + def __ne__(self, other: object) -> object: # type: ignore[override] + if (method := _get_expr_method(self, "__ne__")) is not None: + return method(other) - setattr(RuntimeExpr, name, _special_method) + from .egraph import BaseExpr, ne -# For each of the reflected binary methods, translate to the corresponding non-reflected method -for reflected, non_reflected in REFLECTED_BINARY_METHODS.items(): + return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other)) - def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None: - # All binary methods are also "partial" meaning we should try to upcast first. - return call_method_min_conversion(other, self, __non_reflected) + def __call__( + self, *args: object, **kwargs: object + ) -> object: # define it here only for type checking, it will be overriden below + ... 
- setattr(RuntimeExpr, reflected, _reflected_method) +def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None: + if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods): + return preserved_methods[name].__get__(expr) -def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None: - from .conversion import min_convertable_tp, resolve_literal + if name in expr.__egg_class_decl__.methods: + return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr) + return None - # find a minimum type that both can be converted to - # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats. - min_tp = min_convertable_tp(slf, other, name).to_var() - slf = resolve_literal(min_tp, slf) - other = resolve_literal(min_tp, other) - method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf) - return method(other) +def define_expr_method(name: str) -> None: + """ + Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method. -for name in PRESERVED_METHODS: + Call this if you need a method to be defined on the type itself where overriding with `__getattr__` does not suffice, + like for NumPy's `__array_ufunc__`. 
+ """ - def _preserved_method(self: RuntimeExpr, __name: str = name): - try: - method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name] - except KeyError as e: - raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e - return method(self) + def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs): + fn = _get_expr_method(self, __name) + if fn is None: + raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}") + return fn(*args, **kwargs) + + setattr(RuntimeExpr, name, _defined_method) - setattr(RuntimeExpr, name, _preserved_method) + +for name in TYPE_DEFINED_METHODS: + define_expr_method(name) + + +for name, reversed in itertools.product(NUMERIC_BINARY_METHODS, (False, True)): + + def _numeric_binary_method(self: object, other: object, name: str = name, reversed: bool = reversed) -> object: + """ + Implements numeric binary operations. + + Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either + the LHS or the RHS as exactly the right type and then upcasting the other to that type. + """ + # 1. 
switch if reversed + if reversed: + self, other = other, self + # If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion + if not ( + isinstance(self, RuntimeExpr) + and isinstance(other, RuntimeExpr) + and ( + self.__egg_decls__.check_binary_method_with_types( + name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp + ) + ) + ): + from .conversion import CONVERSIONS, resolve_type, retrieve_conversion_decls + + # tuple of (cost, convert_self) + best_method: ( + tuple[ + int, + Callable[[Any], RuntimeExpr], + ] + | None + ) = None + # Start by checking if we have a LHS that matches exactly and a RHS which can be converted + if ( + isinstance(self, RuntimeExpr) + and ( + desired_other_type := self.__egg_decls__.check_binary_method_with_self_type( + name, self.__egg_typed_expr__.tp + ) + ) + and (converter := CONVERSIONS.get((resolve_type(other), desired_other_type))) + ): + best_method = (converter[0], lambda x: x) + + # Next see if it's possible to convert the LHS and keep the RHS as is + if isinstance(other, RuntimeExpr): + decls = retrieve_conversion_decls() + other_type = other.__egg_typed_expr__.tp + resolved_self_type = resolve_type(self) + for desired_self_type in decls.check_binary_method_with_other_type(name, other_type): + if converter := CONVERSIONS.get((resolved_self_type, desired_self_type)): + cost, convert_self = converter + if best_method is None or best_method[0] > cost: + best_method = (cost, convert_self) + + if not best_method: + raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found") + self = best_method[1](self) + + method_ref = MethodRef(self.__egg_class_name__, name) + fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self) + return fn(other) + + method_name = f"__r{name[2:]}" if reversed else name + setattr(RuntimeExpr, method_name, _numeric_binary_method) def resolve_callable(callable: object) -> 
tuple[CallableRef, Declarations]: diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py index 7bb9f1ea..f4942a8b 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][code].py @@ -25,7 +25,7 @@ def __fn(X, y): _8[2, :,] = _14 _15 = _7 @ _8 _16 = X - _15 - _17 = np.sqrt(np.asarray(np.array(float(1 / 147)), np.dtype(np.float64))) + _17 = np.sqrt(np.asarray(np.array((float(1) / 147)), np.dtype(np.float64))) _18 = X[_0] - _8[0, :,] _19 = X[_2] - _8[1, :,] _20 = X[_4] - _8[2, :,] @@ -49,7 +49,7 @@ def __fn(X, y): _37 = _33[2][:_36, :,] / _29 _38 = _37.T / _33[1][:_36] _39 = np.array(150) * _7 - _40 = _39 * np.array(float(1 / 2)) + _40 = _39 * np.array((float(1) / 2)) _41 = np.sqrt(_40) _42 = _8 - _15 _43 = _41 * _42.T diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py index a1457ca4..18fe5202 100644 --- a/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][expr].py @@ -30,31 +30,22 @@ _IndexKey_3 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) _NDArray_7 = _NDArray_1[IndexKey.ndarray(_NDArray_2 == NDArray.scalar(Value.int(Int(2))))] _NDArray_4[_IndexKey_3] = sum(_NDArray_7, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_7.shape[Int(0)])) +_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) _NDArray_8 = concat( TupleNDArray.from_vec(Vec[NDArray](_NDArray_5 - _NDArray_4[_IndexKey_1], _NDArray_6 - _NDArray_4[_IndexKey_2], _NDArray_7 - _NDArray_4[_IndexKey_3])), OptionalInt.some(Int(0)) ) _NDArray_9 = square(_NDArray_8 - expand_dims(sum(_NDArray_8, _OptionalIntOrTuple_1) / 
NDArray.scalar(Value.int(_NDArray_8.shape[Int(0)])))) _NDArray_10 = sqrt(sum(_NDArray_9, _OptionalIntOrTuple_1) / NDArray.scalar(Value.int(_NDArray_9.shape[Int(0)]))) _NDArray_11 = copy(_NDArray_10) -_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( - sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64))) - * (_NDArray_8 / _NDArray_11), - Boolean(False), -) +_NDArray_11[IndexKey.ndarray(_NDArray_10 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_8 / _NDArray_11), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_12 = ( _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] / _NDArray_11 ).T / _TupleNDArray_1[Int(1)][IndexKey.slice(_Slice_1)] _TupleNDArray_2 = svd( - ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T - ).T - @ _NDArray_12, + (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_3) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_3 @ _NDArray_4)).T).T @ _NDArray_12, Boolean(False), ) ( diff --git a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py index 0234b17d..1ee8a0ab 100644 --- 
a/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py +++ b/python/tests/__snapshots__/test_array_api/test_jit[lda][initial_expr].py @@ -23,6 +23,7 @@ _IndexKey_5 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec(MultiAxisIndexKeyItem.int(Int(2)), _MultiAxisIndexKeyItem_1))) _IndexKey_6 = IndexKey.ndarray(unique_inverse(_NDArray_2)[Int(1)] == NDArray.scalar(Value.int(Int(2)))) _NDArray_3[_IndexKey_5] = mean(asarray(_NDArray_1)[_IndexKey_6], _OptionalIntOrTuple_1) +_Value_1 = Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) _NDArray_4 = zeros(TupleInt.from_vec(Vec[Int](Int(3), Int(4))), OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device)) _IndexKey_7 = IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.int(Int(0)), _MultiAxisIndexKeyItem_1))) _NDArray_4[_IndexKey_7] = mean(_NDArray_1[_IndexKey_2], _OptionalIntOrTuple_1) @@ -41,14 +42,8 @@ OptionalInt.some(Int(0)), ) _NDArray_6 = std(_NDArray_5, _OptionalIntOrTuple_1) -_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar( - Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("1")))) -) -_TupleNDArray_1 = svd( - sqrt(asarray(NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("147"))))), OptionalDType.some(DType.float64))) - * (_NDArray_5 / _NDArray_6), - Boolean(False), -) +_NDArray_6[IndexKey.ndarray(std(_NDArray_5, _OptionalIntOrTuple_1) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(_Value_1) +_TupleNDArray_1 = svd(sqrt(asarray(NDArray.scalar(Value.int(_Value_1.to_int / Int(147))), OptionalDType.some(DType.float64))) * (_NDArray_5 / _NDArray_6), Boolean(False)) _Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_value().to_int)) _NDArray_7 = 
asarray(reshape(asarray(_NDArray_2), TupleInt.from_vec(Vec[Int](Int(-1))))) _NDArray_8 = unique_values(concat(TupleNDArray.from_vec(Vec[NDArray](unique_values(asarray(_NDArray_7)))))) @@ -69,10 +64,7 @@ _NDArray_10[IndexKey.ndarray(_NDArray_9 == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) _NDArray_11 = astype(unique_counts(_NDArray_2)[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("150"), BigInt.from_string("1"))))) _TupleNDArray_2 = svd( - ( - sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.float(Float.rational(BigRat(BigInt.from_string("1"), BigInt.from_string("2")))))) - * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T - ).T + (sqrt((NDArray.scalar(Value.int(Int(150))) * _NDArray_11) * NDArray.scalar(Value.int(_Value_1.to_int / Int(2)))) * (_NDArray_4 - (_NDArray_11 @ _NDArray_4)).T).T @ ( ( _TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey.from_vec(Vec[MultiAxisIndexKeyItem](MultiAxisIndexKeyItem.slice(_Slice_1), _MultiAxisIndexKeyItem_1)))] diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index e8aa1c34..2165f832 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -384,34 +384,6 @@ def __radd__(self, other: Math) -> Math: ... ) -def test_upcast_args(): - # -0.1 + Int(x) -> -0.1 + Float(x) - EGraph() - - class Int(Expr): - def __init__(self, value: i64Like) -> None: ... - - def __add__(self, other: Int) -> Int: ... - - class Float(Expr): - def __init__(self, value: f64Like) -> None: ... - - def __add__(self, other: Float) -> Float: ... - - @classmethod - def from_int(cls, other: Int) -> Float: ... 
- - converter(i64, Int, Int) - converter(f64, Float, Float) - converter(Int, Float, Float.from_int) - - res: Expr = -0.1 + Int(10) # type: ignore[operator,assignment] - assert expr_parts(res) == expr_parts(Float(-0.1) + Float.from_int(Int(10))) - - res: Expr = Int(10) + -0.1 # type: ignore[operator,assignment] - assert expr_parts(res) == expr_parts(Float.from_int(Int(10)) + Float(-0.1)) - - def test_rewrite_upcasts(): class X(Expr): def __init__(self, value: i64Like) -> None: ...