-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathconversion.py
More file actions
274 lines (223 loc) · 10.1 KB
/
conversion.py
File metadata and controls
274 lines (223 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeVar, cast
from .declarations import *
from .pretty import *
from .runtime import *
from .thunk import *
from .type_constraint_solver import TypeConstraintError
if TYPE_CHECKING:
from collections.abc import Generator
from .egraph import BaseExpr
from .type_constraint_solver import TypeConstraintSolver
__all__ = ["ConvertError", "convert", "converter", "get_type_args"]

# Mapping from (source type, target type) to a (cost, function) pair. The function takes a runtime value of
# the source type and returns the corresponding value of the target type.
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
# Global declarations that store all convertible types, so we can query whether they have certain methods
_CONVERSION_DECLS = Declarations.create()
# Declarations waiting to be merged into the global declarations. Merging is deferred so that they are not
# processed until they are actually needed (see `retrieve_conversion_decls`).
_TO_PROCESS_DECLS: list[DeclerationsLike] = []
def retrieve_conversion_decls() -> Declarations:
    """
    Return the global conversion declarations, first merging in every declaration queued by `process_tp`.

    The queue is emptied afterwards, so each deferred declaration is merged only once.
    """
    pending = _TO_PROCESS_DECLS
    _CONVERSION_DECLS.update(*pending)
    pending.clear()
    return _CONVERSION_DECLS
# Type of the Python value a converter function accepts (see `converter`)
T = TypeVar("T")
# Type of the egglog expression a converter function produces (see `converter` / `convert`)
V = TypeVar("V", bound="BaseExpr")
class ConvertError(Exception):
    """Raised when a value cannot be converted to the requested egglog type."""
def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
    """
    Register ``fn`` as a way to turn values of ``from_type`` into the egglog type ``to_type``.

    ``cost`` ranks competing conversions: an already-registered conversion between the same pair of types
    is kept if its cost is lower than or equal to ``cost``. Transitive conversions that chain through
    existing converters are registered as well.

    :raises TypeError: If ``to_type`` does not resolve to an egglog type reference.
    """
    target_ref = process_tp(to_type)
    if isinstance(target_ref, JustTypeRef):
        erased_fn = cast("Callable[[Any], RuntimeExpr]", fn)
        _register_converter(process_tp(from_type), target_ref, erased_fn, cost)
        return
    raise TypeError(f"Expected return type to be a egglog type, got {target_ref}")
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
    """
    Registers a converter from some type to an egglog type, if not already registered.

    Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
    Also, if registering A->B and there is already D->A, then D->B will be registered.

    An existing converter for the same (source, target) pair is kept when its cost is lower than or equal to
    ``cost``. Composed converters are registered with the sum of their parts' costs.

    NOTE(review): the recursion only stops once no new registration is cheaper than the existing one. A cycle of
    negative-cost converters would keep lowering the costs and recurse without bound.
    """
    # A conversion to the same type is never recorded
    if a == b:
        return
    # Keep the existing converter unless the new one is strictly cheaper
    if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost:
        return
    CONVERSIONS[(a, b)] = (cost, a_b)
    # Iterate over a snapshot because the recursive calls below add entries to CONVERSIONS
    for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
        # New A->B followed by existing C->D (B compatible with C) gives A->D.
        # NOTE(review): here the first converter's output (b) is passed as `source`, whereas the branch below
        # passes the second converter's input (a) as `source`. Because `_is_type_compatible` only allows a
        # bound `source` with an unbound `target`, when b != c here c is the unbound one and c.args is (),
        # so no type args are propagated. Confirm this asymmetry is intended.
        if _is_type_compatible(b, c):
            _register_converter(
                a, d, _ComposedConverter(a_b, c_d, c.args if isinstance(c, JustTypeRef) else ()), cost + other_cost
            )
        # Existing C->D followed by new A->B (A compatible with D) gives C->B, with A's type args visible to C->D
        if _is_type_compatible(a, d):
            _register_converter(
                c, b, _ComposedConverter(c_d, a_b, a.args if isinstance(a, JustTypeRef) else ()), cost + other_cost
            )
def _is_type_compatible(source: type | JustTypeRef, target: type | JustTypeRef) -> bool:
"""
Types must be equal or also support unbound to bound typevar like B -> B[C]
"""
if source == target:
return True
if isinstance(source, JustTypeRef) and isinstance(target, JustTypeRef) and source.args and not target.args:
return source.name == target.name
# TODO: Support case where B[T] where T is typevar is mapped to B[C]
return False
@dataclass
class _ComposedConverter:
"""
A converter which is composed of multiple converters.
_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
We use the dataclass instead of the lambda to make it easier to debug.
"""
a_b: Callable[[Any], RuntimeExpr]
b_c: Callable[[Any], RuntimeExpr]
b_args: tuple[JustTypeRef, ...]
def __call__(self, x: Any) -> RuntimeExpr:
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
# when converting from A -> B
if self.b_args:
with with_type_args(self.b_args, retrieve_conversion_decls):
first_res = self.a_b(x)
else:
first_res = self.a_b(x)
return self.b_c(first_res)
def __str__(self) -> str:
return f"{self.b_c} ∘ {self.a_b}"
def convert(source: object, target: type[V]) -> V:
    """
    Convert ``source`` into a value of the egglog type ``target``.

    Raises ConvertError (from ``resolve_literal``) when no registered conversion applies.
    """
    assert isinstance(target, RuntimeClass)
    converted = resolve_literal(target.__egg_tp__, source, target.__egg_decls_thunk__)
    return cast("V", converted)
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
    """
    Convert ``source`` into a value of the same egglog type as the expression ``target``.

    The target's own declarations are used for the conversion.
    """
    target_var_tp = target.__egg_typed_expr__.tp.to_var()
    return resolve_literal(target_var_tp, source, Thunk.value(target.__egg_decls__))
def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
    """
    Normalize a type before using it in a conversion.

    Plain Python types are returned unchanged. Egglog runtime classes are queued for merging into the global
    conversion declarations and resolved to their concrete type ref.
    """
    if not isinstance(tp, RuntimeClass):
        return tp
    _TO_PROCESS_DECLS.append(tp)
    return tp.__egg_tp__.to_just()
# def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
# """
# Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
# """
# decls = _retrieve_conversion_decls().copy()
# if isinstance(a, RuntimeExpr):
# decls |= a
# if isinstance(b, RuntimeExpr):
# decls |= b
# a_tp = _get_tp(a)
# b_tp = _get_tp(b)
# # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
# if not (
# (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
# or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
# ):
# raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
# a_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
# }
# b_converts_to = {
# to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
# }
# if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
# a_converts_to[a_tp] = 0
# if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
# b_converts_to[b_tp] = 0
# common = set(a_converts_to) & set(b_converts_to)
# if not common:
# raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
# return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
def identity(x: object) -> object:
    """Return the argument unchanged."""
    result = x
    return result
# Type args (as runtime classes) of the target type currently being converted to. Set by `with_type_args` and
# read by `get_type_args`. It has no default, so reading it outside a `with_type_args` block raises LookupError.
TYPE_ARGS = ContextVar[tuple[RuntimeClass, ...]]("TYPE_ARGS")
def get_type_args() -> tuple[type, ...]:
    """
    Return the type args of the egglog type currently being converted to.

    Meant to be called from inside a converter function, where ``with_type_args`` has set them. Outside such
    a context the underlying context variable is unset and LookupError is raised.
    """
    current = TYPE_ARGS.get()
    return cast("tuple[type, ...]", current)
@contextmanager
def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declarations]) -> Generator[None, None, None]:
    """
    Expose ``args`` (wrapped as runtime classes over ``decls``) to ``get_type_args`` for the duration of the block.

    The previous value is restored on exit, even if the block raises.
    """
    runtime_args = tuple(RuntimeClass(decls, ref.to_var()) for ref in args)
    token = TYPE_ARGS.set(runtime_args)
    try:
        yield
    finally:
        TYPE_ARGS.reset(token)
def resolve_literal(
    tp: TypeOrVarRef,
    arg: object,
    decls: Callable[[], Declarations] = retrieve_conversion_decls,
    tcs: TypeConstraintSolver | None = None,
    cls_name: str | None = None,
) -> RuntimeExpr:
    """
    Try to convert an object to a type, raising a ConvertError if it is not possible.

    If the type has vars in it, we try to resolve them into concrete types using the type constraint solver.
    If they cannot be resolved, we assume that the value passed in will resolve them.

    :param tp: The target type, which may contain type variables.
    :param arg: The value to convert, either a runtime expression or a plain Python value.
    :param decls: Declarations thunk used to build the runtime classes that `get_type_args` returns while the
        converter function runs.
    :param tcs: Optional type constraint solver, used to substitute and infer the type variables in ``tp``.
    :param cls_name: Passed through to the solver's ``substitute_typevars`` / ``infer_typevars``.
    :raises ConvertError: If no registered conversion goes from the argument's type to the target type.
    """
    # Egglog type for runtime expressions; otherwise the Python type (or its metaclass, see `resolve_type`)
    arg_type = resolve_type(arg)

    # `to_just` raises TypeVarError when `tp` still contains type variables
    try:
        tp_just = tp.to_just()
    except TypeVarError:
        # If this is a generic arg but passed in a non runtime expression, try to resolve the generic
        # args first based on the existing type constraint solver
        # NOTE(review): `if tcs:` relies on the solver's truthiness; `is not None` would be more explicit
        if tcs:
            try:
                tp_just = tcs.substitute_typevars(tp, cls_name)
            # If we can't resolve the type var yet, then just assume the arg already has the right type
            except TypeConstraintError:
                assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
                tp_just = arg.__egg_typed_expr__.tp
        else:
            # Without a solver the type var can't be resolved, so the arg has to be a runtime expression
            # and is returned unchanged
            assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
            return arg
    # Tell the solver the concrete target type (presumably so it can bind the type vars in `tp`)
    if tcs:
        tcs.infer_typevars(tp, tp_just, cls_name)
    if arg_type == tp_just:
        # If the type is an egg type, it has to be a runtime expr
        assert isinstance(arg, RuntimeExpr)
        return arg
    # Try all parent types as well, if we are converting from a Python type
    for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
        if (key := (arg_type_instance, tp_just)) in CONVERSIONS:
            fn = CONVERSIONS[key][1]
            break
        # Try broadening if we have a converter to the general type instead of the specific one too, for generics
        if tp_just.args and (key := (arg_type_instance, JustTypeRef(tp_just.name))) in CONVERSIONS:
            fn = CONVERSIONS[key][1]
            break
    # for-else: reached only when the loop found no converter (no `break`), so raise an error
    else:
        raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")
    # Run the converter with the target's type args visible through `get_type_args`
    with with_type_args(tp_just.args, decls):
        return fn(arg)
def _debug_print_converers():
"""
Prints a mapping of all source types to target types that have a conversion function.
"""
source_to_targets = defaultdict(list)
for source, target in CONVERSIONS:
source_to_targets[source].append(target)
def resolve_type(x: object) -> JustTypeRef | type:
    """
    Return the key used to look up conversions for ``x``.

    Runtime expressions map to their egglog type. Any other value maps to its Python type, unless that type
    has a custom metaclass, in which case the metaclass itself is used as the index.

    NOTE(review): this means instances of classes built with a metaclass (e.g. ``abc.ABCMeta`` subclasses)
    resolve to the metaclass rather than their own class. Confirm this is intended.
    """
    if isinstance(x, RuntimeExpr):
        return x.__egg_typed_expr__.tp
    py_type = type(x)
    meta = type(py_type)
    return py_type if meta is type else meta