Skip to content

Commit b5ceeea

Browse files
authored
Add cast function. (#944)
Allows users to cast between types: ```py assert client.get(std.str.cast(std.int64(1))) == "1" assert client.get(std.array[std.str].cast( std.array[std.int64]( [std.int64(1), std.int64(2), std.int64(3)] ) ) == ["1", "2", "3"] ``` Supports scalars, enums, array, tuple, range. Multiranges not supported, not sure we can even construct them in the qb right now.
1 parent ff76b6b commit b5ceeea

4 files changed

Lines changed: 304 additions & 23 deletions

File tree

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,14 @@ def _indirection_key(path: Indirection) -> tuple[str, ...]:
839839
SchemaPath("std", "multirange"),
840840
}
841841
)
842+
COLLECTION_TYPES = frozenset(
843+
{
844+
SchemaPath("std", "array"),
845+
SchemaPath("std", "tuple"),
846+
SchemaPath("std", "range"),
847+
SchemaPath("std", "multirange"),
848+
}
849+
)
842850

843851
PSEUDO_TYPES = frozenset(("anytuple", "anyobject", "anytype"))
844852

@@ -2352,37 +2360,39 @@ def write_generic_types(
23522360
t_anypt = self.declare_typevar("_T_anypoint", bound=anypoint)
23532361
self.write(f'_Tt = {typevartup}("_Tt")')
23542362

2363+
# Order the bases with more specific types first so that
2364+
# __gel_reflection__ is resolved correctly.
23552365
generics = {
23562366
SchemaPath("std", "anytype"): [
23572367
geltype,
23582368
],
23592369
SchemaPath("std", "anyobject"): [
2360-
"anytype",
23612370
gelmodel,
2371+
"anytype",
23622372
],
23632373
SchemaPath("std", "anytuple"): [
2364-
"anytype",
23652374
anytuple,
2375+
"anytype",
23662376
],
23672377
SchemaPath("std", "anynamedtuple"): [
2368-
"anytuple",
23692378
anynamedtuple,
2379+
"anytuple",
23702380
],
23712381
SchemaPath("std", "tuple"): [
2372-
"anytuple",
23732382
f"{tup}[{unpack}[_Tt]]",
2383+
"anytuple",
23742384
],
23752385
SchemaPath("std", "array"): [
2376-
"anytype",
23772386
f"{arr}[{t_anytype}]",
2387+
"anytype",
23782388
],
23792389
SchemaPath("std", "range"): [
2380-
"anytype",
23812390
f"{rang}[{t_anypt}]",
2391+
"anytype",
23822392
],
23832393
SchemaPath("std", "multirange"): [
2384-
"anytype",
23852394
f"{mrang}[{t_anypt}]",
2395+
"anytype",
23862396
],
23872397
}
23882398

@@ -2515,7 +2525,17 @@ def write_generic_types(
25152525
suggested_module_alias=rel_import.module_alias,
25162526
)
25172527
type_ident = ident(gt.name)
2518-
self.write(f"{type_ident} = {imported_name}")
2528+
2529+
if gt in COLLECTION_TYPES:
2530+
# Import collection classes directly so mypy sees them as
2531+
# generic classes, not type aliases
2532+
self.write(
2533+
f"from {rel_import.module} "
2534+
f"import {type_ident} as {type_ident}"
2535+
)
2536+
else:
2537+
self.write(f"{type_ident} = {imported_name}")
2538+
25192539
self.export(type_ident)
25202540

25212541
def _write_enum_scalar_type(
@@ -2533,9 +2553,43 @@ def _write_enum_scalar_type(
25332553
self.write_description(stype)
25342554
for value in stype.enum_values:
25352555
self.write(f"{ident(value)} = {value!r}")
2556+
2557+
self.write()
25362558
self.write_type_reflection(stype)
2559+
2560+
self.write()
2561+
self._write_enum_scalar_cast(stype)
2562+
25372563
self.write_section_break()
25382564

2565+
def _write_enum_scalar_cast(
2566+
self,
2567+
stype: reflection.ScalarType,
2568+
) -> None:
2569+
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
2570+
2571+
type_ = self.import_name("builtins", "type")
2572+
self_ = self.import_name("typing_extensions", "Self")
2573+
type_self = f"{type_}[{self_}]"
2574+
2575+
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2576+
cast_op = self.import_name(BASE_IMPL, "CastOp")
2577+
2578+
with self._classmethod_def(
2579+
"cast",
2580+
[f"expr: {expr_compat}"],
2581+
type_self,
2582+
):
2583+
self.write(f"return {aexpr}( # type: ignore [return-value]")
2584+
with self.indented():
2585+
self.write("cls,")
2586+
self.write(f"{cast_op}(")
2587+
with self.indented():
2588+
self.write("expr=expr,")
2589+
self.write("type_=cls.__gel_reflection__.type_name,")
2590+
self.write(")")
2591+
self.write(")")
2592+
25392593
def _write_scalar_type(
25402594
self,
25412595
stype: reflection.ScalarType,
@@ -2624,6 +2678,8 @@ def _write_regular_scalar_type(
26242678

26252679
self.export("anyenum")
26262680

2681+
is_generic = type_name in GENERIC_TYPES
2682+
26272683
if not runtime_parents:
26282684
typecheck_parents = [self.get_type(self._types_by_name["anytype"])]
26292685
runtime_parents = typecheck_parents
@@ -2725,6 +2781,9 @@ def _write_regular_scalar_type(
27252781
):
27262782
self.write("...")
27272783

2784+
if not is_generic:
2785+
self._write_regular_scalar_cast(stype, signature_only=True)
2786+
27282787
self.write()
27292788

27302789
with self.not_type_checking():
@@ -2739,8 +2798,55 @@ def _write_regular_scalar_type(
27392798
self.write()
27402799
self.write_type_reflection(stype)
27412800

2801+
if not is_generic:
2802+
self.write()
2803+
self._write_regular_scalar_cast(
2804+
stype, signature_only=False
2805+
)
2806+
27422807
self.write_section_break()
27432808

2809+
def _write_regular_scalar_cast(
2810+
self,
2811+
stype: reflection.ScalarType,
2812+
*,
2813+
signature_only: bool,
2814+
) -> None:
2815+
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
2816+
2817+
type_ = self.import_name("builtins", "type")
2818+
self_ = self.import_name("typing_extensions", "Self")
2819+
type_self = f"{type_}[{self_}]"
2820+
2821+
if signature_only:
2822+
self.write()
2823+
with self._classmethod_def(
2824+
"cast",
2825+
[f"expr: {expr_compat}"],
2826+
type_self,
2827+
):
2828+
self.write("...")
2829+
2830+
else:
2831+
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2832+
cast_op = self.import_name(BASE_IMPL, "CastOp")
2833+
2834+
self.write()
2835+
with self._classmethod_def(
2836+
"cast",
2837+
[f"expr: {expr_compat}"],
2838+
type_self,
2839+
):
2840+
self.write(f"return {aexpr}( # type: ignore [return-value]")
2841+
with self.indented():
2842+
self.write("cls,")
2843+
self.write(f"{cast_op}(")
2844+
with self.indented():
2845+
self.write("expr=expr,")
2846+
self.write("type_=cls.__gel_reflection__.type_name,")
2847+
self.write(")")
2848+
self.write(")")
2849+
27442850
def render_callable_return_type(
27452851
self,
27462852
tp: reflection.Type,

gel/_internal/_qbmodel/_abstract/_primitive.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ def _reconstruct_from_pickle(
269269
def __gel_get_py_type__(cls) -> type:
270270
return list
271271

272+
def __edgeql_literal__(self) -> _qb.Literal:
273+
return _qb.Literal(
274+
type_=type(self).__gel_reflection__.type_name,
275+
val=self,
276+
)
277+
278+
@classmethod
279+
def cast(cls, expr: _qb.ExprCompatible) -> type[Array[_T]]:
280+
return _qb.AnnotatedExpr( # type: ignore [return-value]
281+
cls,
282+
_qb.CastOp(
283+
expr=expr,
284+
type_=cls.__gel_reflection__.type_name,
285+
),
286+
)
287+
272288

273289
_Ts = TypeVarTuple("_Ts")
274290

@@ -335,6 +351,27 @@ class __gel_reflection__(GelPrimitiveType.__gel_reflection__): # noqa: N801
335351
def __gel_get_py_type__(cls) -> type:
336352
return tuple
337353

354+
def __edgeql_literal__(self) -> _qb.Literal:
355+
return _qb.Literal(
356+
type_=type(self).__gel_reflection__.type_name,
357+
val=self,
358+
)
359+
360+
@classmethod
361+
def cast(cls, expr: _qb.ExprCompatible) -> Self:
362+
return _qb.AnnotatedExpr( # type: ignore [return-value]
363+
cls,
364+
_qb.CastOp(
365+
expr=expr,
366+
type_=cls.__gel_reflection__.type_name,
367+
),
368+
)
369+
370+
if TYPE_CHECKING:
371+
372+
def __new__(cls, args: tuple[Any, ...]) -> Self: ...
373+
def __init__(self, args: tuple[Unpack[_Ts]]): ...
374+
338375

339376
if TYPE_CHECKING:
340377

@@ -374,6 +411,16 @@ class __gel_reflection__(GelPrimitiveType.__gel_reflection__): # noqa: N801
374411

375412
return __gel_reflection__
376413

414+
@classmethod
415+
def cast(cls, expr: _qb.ExprCompatible) -> type[Range[_T]]:
416+
return _qb.AnnotatedExpr( # type: ignore [return-value]
417+
cls,
418+
_qb.CastOp(
419+
expr=expr,
420+
type_=cls.__gel_reflection__.type_name,
421+
),
422+
)
423+
377424

378425
if TYPE_CHECKING:
379426

@@ -459,6 +506,9 @@ def get_py_type_from_gel_type(tp: type[GelType]) -> Any:
459506
case t if issubclass(t, PyTypeScalar):
460507
return t.__gel_py_type__
461508

509+
case t if issubclass(t, AnyEnum):
510+
return t
511+
462512
case t:
463513
raise NotImplementedError(
464514
f"get_py_type({t.__name__}) is not implemented"

gel/models/pydantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OPERAND_IS_ALIAS,
2323
AnnotatedExpr,
2424
BaseAlias,
25+
CastOp,
2526
EmptyDirection,
2627
Direction,
2728
GelLinkMetadata,
@@ -116,6 +117,7 @@
116117
"ArrayMeta",
117118
"BaseAlias",
118119
"Cardinality",
120+
"CastOp",
119121
"ComputedLink",
120122
"ComputedLinkWithProps",
121123
"ComputedMultiLink",

0 commit comments

Comments
 (0)