Skip to content

Commit 0758e85

Browse files
committed
Fix union and coalesce expressions not decoding to the correct type.
1 parent 0553a84 commit 0758e85

5 files changed

Lines changed: 150 additions & 1 deletion

File tree

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5795,6 +5795,43 @@ def resolve(
57955795
f"# type: ignore [assignment, misc, unused-ignore]"
57965796
)
57975797

5798+
if function.schemapath in {
5799+
SchemaPath('std', 'UNION'),
5800+
SchemaPath('std', 'IF'),
5801+
SchemaPath('std', '??'),
5802+
}:
5803+
# Special case for the UNION, IF and ?? operators
5804+
# Produce a union type instead of just taking the first
5805+
# valid type.
5806+
#
5807+
# See gel: compile_operator
5808+
create_union = self.import_name(
5809+
BASE_IMPL, "create_union"
5810+
)
5811+
5812+
tvars: list[str] = []
5813+
for param, path in sources:
5814+
if param.name not in required_generic_params:
5815+
continue
5816+
5817+
pn = param_vars[param.name]
5818+
tvar = f"__t_{pn}__"
5819+
5820+
resolve(pn, path, tvar)
5821+
tvars.append(tvar)
5822+
5823+
self.write(
5824+
f"{gtvar} = {tvars[0]} "
5825+
f"# type: ignore [assignment, misc, unused-ignore]"
5826+
)
5827+
for tvar in tvars[1:]:
5828+
self.write(
5829+
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
5830+
f"# type: ignore [assignment, misc, unused-ignore]"
5831+
)
5832+
5833+
continue
5834+
57985835
# Try to infer generic type from required params first
57995836
for param, path in sources:
58005837
if param.name in required_generic_params:

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565
from ._methods import (
6666
BaseGelModel,
6767
BaseGelModelIntersection,
68+
BaseGelModelUnion,
69+
create_union,
6870
)
6971

7072

@@ -132,6 +134,7 @@
132134
"ArrayMeta",
133135
"BaseGelModel",
134136
"BaseGelModelIntersection",
137+
"BaseGelModelUnion",
135138
"ComputedLinkSet",
136139
"ComputedLinkWithPropsSet",
137140
"ComputedMultiLinkDescriptor",
@@ -174,6 +177,7 @@
174177
"TupleMeta",
175178
"UUIDImpl",
176179
"copy_or_ref_lprops",
180+
"create_union",
177181
"empty_set_if_none",
178182
"field_descriptor",
179183
"get_base_scalars_backed_by_py_type",

gel/_internal/_qbmodel/_abstract/_methods.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from gel._internal import _qb
2020
from gel._internal._schemapath import (
2121
TypeNameIntersection,
22+
TypeNameUnion,
2223
)
2324
from gel._internal._xmethod import classonlymethod
2425

@@ -254,6 +255,16 @@ class BaseGelModelIntersection(
254255
rhs: ClassVar[type[AbstractGelModel]]
255256

256257

258+
class BaseGelModelUnion(
259+
BaseGelModel,
260+
Generic[_T_Lhs, _T_Rhs],
261+
):
262+
__gel_type_class__: ClassVar[type]
263+
264+
lhs: ClassVar[type[AbstractGelModel]]
265+
rhs: ClassVar[type[AbstractGelModel]]
266+
267+
257268
T = TypeVar('T')
258269
U = TypeVar('U')
259270

@@ -427,3 +438,93 @@ def process_path_alias(
427438
_type_intersection_cache[lhs][rhs] = result
428439

429440
return result
441+
442+
443+
_type_union_cache: weakref.WeakKeyDictionary[
444+
type[AbstractGelModel],
445+
weakref.WeakKeyDictionary[
446+
type[AbstractGelModel],
447+
type[
448+
BaseGelModelUnion[type[AbstractGelModel], type[AbstractGelModel]]
449+
],
450+
],
451+
] = weakref.WeakKeyDictionary()
452+
453+
454+
def create_union(
455+
lhs: _T_Lhs,
456+
rhs: _T_Rhs,
457+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]:
458+
"""Create a runtime union type which acts like a GelModel."""
459+
460+
if (lhs_entry := _type_union_cache.get(lhs)) and (
461+
rhs_entry := lhs_entry.get(rhs)
462+
):
463+
return rhs_entry # type: ignore[return-value]
464+
465+
# Combine pointer reflections from args
466+
ptr_reflections: dict[str, _qb.GelPointerReflection] = {
467+
p_name: p_refl
468+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
469+
if p_name in rhs.__gel_reflection__.pointers
470+
}
471+
472+
# Create type reflection for union type
473+
class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801
474+
expr_object_types: set[type[AbstractGelModel]] = getattr(
475+
lhs.__gel_reflection__, 'expr_object_types', {lhs}
476+
) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs})
477+
478+
type_name = TypeNameUnion(
479+
args=(
480+
lhs.__gel_reflection__.type_name,
481+
rhs.__gel_reflection__.type_name,
482+
)
483+
)
484+
485+
pointers = ptr_reflections
486+
487+
@classmethod
488+
def object(
489+
cls,
490+
) -> Any:
491+
raise NotImplementedError(
492+
"Type expressions schema objects are inaccessible"
493+
)
494+
495+
result = type(
496+
f"({lhs.__name__} | {rhs.__name__})",
497+
(BaseGelModelUnion,),
498+
{
499+
'lhs': lhs,
500+
'rhs': rhs,
501+
'__gel_reflection__': __gel_reflection__,
502+
},
503+
)
504+
505+
# Generate path aliases for pointers.
506+
#
507+
# These are used to generate the appropriate path prefix when getting
508+
# pointers in shapes.
509+
path_aliases: dict[str, _qb.PathAlias] = {
510+
p_name: l_path_alias
511+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
512+
if (
513+
hasattr(lhs, p_name)
514+
and (l_path_alias := getattr(lhs, p_name, None)) is not None
515+
and isinstance(l_path_alias, _qb.PathAlias)
516+
)
517+
if (
518+
hasattr(rhs, p_name)
519+
and (r_path_alias := getattr(rhs, p_name, None)) is not None
520+
and isinstance(r_path_alias, _qb.PathAlias)
521+
)
522+
}
523+
for p_name, path_alias in path_aliases.items():
524+
setattr(result, p_name, path_alias)
525+
526+
if lhs not in _type_union_cache:
527+
_type_union_cache[lhs] = weakref.WeakKeyDictionary()
528+
_type_union_cache[lhs][rhs] = result
529+
530+
return result

gel/_internal/_typing_dispatch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
from gel._internal import _typing_inspect
3939
from gel._internal import _typing_parametric
4040
from gel._internal._utils import type_repr
41-
from gel._internal._qbmodel._abstract._methods import BaseGelModelIntersection
41+
from gel._internal._qbmodel._abstract._methods import (
42+
BaseGelModelIntersection,
43+
BaseGelModelUnion,
44+
)
4245

4346
_P = ParamSpec("_P")
4447
_R_co = TypeVar("_R_co", covariant=True)
@@ -70,6 +73,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
7073

7174
if issubclass(lhs, BaseGelModelIntersection):
7275
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
76+
elif issubclass(lhs, BaseGelModelUnion):
77+
return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
7378

7479
if _typing_inspect.is_generic_alias(tp):
7580
origin = typing.get_origin(tp)

gel/models/pydantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
PyTypeScalarConstraint,
6969
RangeMeta,
7070
UUIDImpl,
71+
create_union,
7172
empty_set_if_none,
7273
)
7374

@@ -195,6 +196,7 @@
195196
"classonlymethod",
196197
"computed_field",
197198
"construct_infix_op_chain",
199+
"create_union",
198200
"dispatch_overload",
199201
"empty_set_if_none",
200202
)

0 commit comments

Comments
 (0)