Skip to content

Commit b14d4d2

Browse files
committed
Fix union and coalesce expressions not decoding to the correct type.
1 parent 8a956d6 commit b14d4d2

5 files changed

Lines changed: 178 additions & 12 deletions

File tree

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6170,6 +6170,45 @@ def resolve(
61706170
f"# type: ignore [assignment, misc, unused-ignore]"
61716171
)
61726172

6173+
if function.schemapath in {
6174+
SchemaPath('std', 'UNION'),
6175+
SchemaPath('std', 'IF'),
6176+
SchemaPath('std', '??'),
6177+
}:
6178+
# Special case for the UNION, IF and ?? operators
6179+
# Produce a union type instead of just taking the first
6180+
# valid type.
6181+
#
6182+
# See gel: edb.compiler.func.compile_operator
6183+
create_union = self.import_name(
6184+
BASE_IMPL, "create_optional_union"
6185+
)
6186+
6187+
tvars: list[str] = []
6188+
for param, path in sources:
6189+
if (
6190+
param.name in required_generic_params
6191+
or param.name in optional_generic_params
6192+
):
6193+
pn = param_vars[param.name]
6194+
tvar = f"__t_{pn}__"
6195+
6196+
resolve(pn, path, tvar)
6197+
tvars.append(tvar)
6198+
6199+
self.write(
6200+
f"{gtvar} = {tvars[0]} "
6201+
f"# type: ignore [assignment, misc, unused-ignore]"
6202+
)
6203+
for tvar in tvars[1:]:
6204+
self.write(
6205+
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
6206+
f"# type: ignore ["
6207+
f"assignment, misc, unused-ignore]"
6208+
)
6209+
6210+
continue
6211+
61736212
# Try to infer generic type from required params first
61746213
for param, path in sources:
61756214
if param.name in required_generic_params:

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
from ._methods import (
6969
BaseGelModel,
7070
BaseGelModelIntersection,
71+
BaseGelModelUnion,
72+
create_optional_union,
73+
create_union,
7174
)
7275

7376

@@ -138,6 +141,7 @@
138141
"ArrayMeta",
139142
"BaseGelModel",
140143
"BaseGelModelIntersection",
144+
"BaseGelModelUnion",
141145
"ComputedLinkSet",
142146
"ComputedLinkWithPropsSet",
143147
"ComputedMultiLinkDescriptor",
@@ -181,6 +185,8 @@
181185
"TupleMeta",
182186
"UUIDImpl",
183187
"copy_or_ref_lprops",
188+
"create_optional_union",
189+
"create_union",
184190
"empty_set_if_none",
185191
"field_descriptor",
186192
"get_base_scalars_backed_by_py_type",

gel/_internal/_qbmodel/_abstract/_methods.py

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919
from gel._internal import _qb
2020
from gel._internal._schemapath import (
21-
TypeNameIntersection,
2221
TypeNameExpr,
22+
TypeNameIntersection,
23+
TypeNameUnion,
2324
)
2425
from gel._internal import _type_expression
2526
from gel._internal._xmethod import classonlymethod
@@ -270,6 +271,17 @@ class BaseGelModelIntersectionBacklinks(
270271
rhs: ClassVar[type[AbstractGelObjectBacklinksModel]]
271272

272273

274+
class BaseGelModelUnion(
275+
BaseGelModel,
276+
_type_expression.Union,
277+
Generic[_T_Lhs, _T_Rhs],
278+
):
279+
__gel_type_class__: ClassVar[type]
280+
281+
lhs: ClassVar[type[AbstractGelModel]]
282+
rhs: ClassVar[type[AbstractGelModel]]
283+
284+
273285
T = TypeVar('T')
274286
U = TypeVar('U')
275287

@@ -318,6 +330,17 @@ def combine_dicts(
318330
return result
319331

320332

333+
def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
334+
if lhs == rhs:
335+
return (lhs,)
336+
elif issubclass(lhs, rhs):
337+
return (lhs, rhs)
338+
elif issubclass(rhs, lhs):
339+
return (rhs, lhs)
340+
else:
341+
return (lhs, rhs)
342+
343+
321344
_type_intersection_cache: weakref.WeakKeyDictionary[
322345
type[AbstractGelModel],
323346
weakref.WeakKeyDictionary[
@@ -430,17 +453,6 @@ def object(
430453
return result
431454

432455

433-
def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
434-
if lhs == rhs:
435-
return (lhs,)
436-
elif issubclass(lhs, rhs):
437-
return (lhs, rhs)
438-
elif issubclass(rhs, lhs):
439-
return (rhs, lhs)
440-
else:
441-
return (lhs, rhs)
442-
443-
444456
def create_intersection_backlinks(
445457
lhs_backlinks: type[AbstractGelObjectBacklinksModel],
446458
rhs_backlinks: type[AbstractGelObjectBacklinksModel],
@@ -500,3 +512,106 @@ def create_intersection_backlinks(
500512
)
501513

502514
return backlinks
515+
516+
517+
_type_union_cache: weakref.WeakKeyDictionary[
518+
type[AbstractGelModel],
519+
weakref.WeakKeyDictionary[
520+
type[AbstractGelModel],
521+
type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]],
522+
],
523+
] = weakref.WeakKeyDictionary()
524+
525+
526+
def create_optional_union(
527+
lhs: type[_T_Lhs] | None,
528+
rhs: type[_T_Rhs] | None,
529+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None:
530+
if lhs is None:
531+
return rhs
532+
elif rhs is None:
533+
return lhs
534+
else:
535+
return create_union(lhs, rhs)
536+
537+
538+
def create_union(
539+
lhs: type[_T_Lhs],
540+
rhs: type[_T_Rhs],
541+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]:
542+
"""Create a runtime union type which acts like a GelModel."""
543+
544+
if (lhs_entry := _type_union_cache.get(lhs)) and (
545+
rhs_entry := lhs_entry.get(rhs)
546+
):
547+
return rhs_entry # type: ignore[return-value]
548+
549+
# Combine pointer reflections from args
550+
ptr_reflections: dict[str, _qb.GelPointerReflection] = {
551+
p_name: p_refl
552+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
553+
if p_name in rhs.__gel_reflection__.pointers
554+
}
555+
556+
# Create type reflection for union type
557+
class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801
558+
expr_object_types: set[type[AbstractGelModel]] = getattr(
559+
lhs.__gel_reflection__, 'expr_object_types', {lhs}
560+
) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs})
561+
562+
type_name = TypeNameUnion(
563+
args=(
564+
lhs.__gel_reflection__.type_name,
565+
rhs.__gel_reflection__.type_name,
566+
)
567+
)
568+
569+
pointers = ptr_reflections
570+
571+
@classmethod
572+
def object(
573+
cls,
574+
) -> Any:
575+
raise NotImplementedError(
576+
"Type expressions schema objects are inaccessible"
577+
)
578+
579+
# Create the resulting union type
580+
result = type(
581+
f"({lhs.__name__} | {rhs.__name__})",
582+
(BaseGelModelUnion,),
583+
{
584+
'lhs': lhs,
585+
'rhs': rhs,
586+
'__gel_reflection__': __gel_reflection__,
587+
"__gel_proxied_dunders__": frozenset(
588+
{
589+
"__backlinks__",
590+
}
591+
),
592+
},
593+
)
594+
595+
# Generate field descriptors.
596+
descriptors: dict[str, ModelFieldDescriptor] = {
597+
p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__)
598+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
599+
if (
600+
hasattr(lhs, p_name)
601+
and (l_path_alias := getattr(lhs, p_name, None)) is not None
602+
and isinstance(l_path_alias, _qb.PathAlias)
603+
)
604+
if (
605+
hasattr(rhs, p_name)
606+
and (r_path_alias := getattr(rhs, p_name, None)) is not None
607+
and isinstance(r_path_alias, _qb.PathAlias)
608+
)
609+
}
610+
for p_name, descriptor in descriptors.items():
611+
setattr(result, p_name, descriptor)
612+
613+
if lhs not in _type_union_cache:
614+
_type_union_cache[lhs] = weakref.WeakKeyDictionary()
615+
_type_union_cache[lhs][rhs] = result
616+
617+
return result

gel/_internal/_typing_dispatch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
7070

7171
if issubclass(lhs, _type_expression.Intersection):
7272
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
73+
elif issubclass(lhs, _type_expression.Union):
74+
return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
7375

7476
if _typing_inspect.is_generic_alias(tp):
7577
origin = typing.get_origin(tp)

gel/models/pydantic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
PyTypeScalarConstraint,
7777
RangeMeta,
7878
UUIDImpl,
79+
create_optional_union,
80+
create_union,
7981
empty_set_if_none,
8082
)
8183

@@ -215,6 +217,8 @@
215217
"classonlymethod",
216218
"computed_field",
217219
"construct_infix_op_chain",
220+
"create_optional_union",
221+
"create_union",
218222
"dispatch_overload",
219223
"empty_set_if_none",
220224
)

0 commit comments

Comments
 (0)