|
18 | 18 |
|
19 | 19 | from gel._internal import _qb |
20 | 20 | from gel._internal._schemapath import ( |
21 | | - TypeNameIntersection, |
22 | 21 | TypeNameExpr, |
| 22 | + TypeNameIntersection, |
| 23 | + TypeNameUnion, |
23 | 24 | ) |
24 | 25 | from gel._internal import _type_expression |
25 | 26 | from gel._internal._xmethod import classonlymethod |
@@ -270,6 +271,17 @@ class BaseGelModelIntersectionBacklinks( |
270 | 271 | rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] |
271 | 272 |
|
272 | 273 |
|
| 274 | +class BaseGelModelUnion( |
| 275 | + BaseGelModel, |
| 276 | + _type_expression.Union, |
| 277 | + Generic[_T_Lhs, _T_Rhs], |
| 278 | +): |
| 279 | + __gel_type_class__: ClassVar[type] |
| 280 | + |
| 281 | + lhs: ClassVar[type[AbstractGelModel]] |
| 282 | + rhs: ClassVar[type[AbstractGelModel]] |
| 283 | + |
| 284 | + |
273 | 285 | T = TypeVar('T') |
274 | 286 | U = TypeVar('U') |
275 | 287 |
|
@@ -318,6 +330,17 @@ def combine_dicts( |
318 | 330 | return result |
319 | 331 |
|
320 | 332 |
|
| 333 | +def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: |
| 334 | + if lhs == rhs: |
| 335 | + return (lhs,) |
| 336 | + elif issubclass(lhs, rhs): |
| 337 | + return (lhs, rhs) |
| 338 | + elif issubclass(rhs, lhs): |
| 339 | + return (rhs, lhs) |
| 340 | + else: |
| 341 | + return (lhs, rhs) |
| 342 | + |
| 343 | + |
321 | 344 | _type_intersection_cache: weakref.WeakKeyDictionary[ |
322 | 345 | type[AbstractGelModel], |
323 | 346 | weakref.WeakKeyDictionary[ |
@@ -430,17 +453,6 @@ def object( |
430 | 453 | return result |
431 | 454 |
|
432 | 455 |
|
433 | | -def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: |
434 | | - if lhs == rhs: |
435 | | - return (lhs,) |
436 | | - elif issubclass(lhs, rhs): |
437 | | - return (lhs, rhs) |
438 | | - elif issubclass(rhs, lhs): |
439 | | - return (rhs, lhs) |
440 | | - else: |
441 | | - return (lhs, rhs) |
442 | | - |
443 | | - |
444 | 456 | def create_intersection_backlinks( |
445 | 457 | lhs_backlinks: type[AbstractGelObjectBacklinksModel], |
446 | 458 | rhs_backlinks: type[AbstractGelObjectBacklinksModel], |
@@ -500,3 +512,106 @@ def create_intersection_backlinks( |
500 | 512 | ) |
501 | 513 |
|
502 | 514 | return backlinks |
| 515 | + |
| 516 | + |
| 517 | +_type_union_cache: weakref.WeakKeyDictionary[ |
| 518 | + type[AbstractGelModel], |
| 519 | + weakref.WeakKeyDictionary[ |
| 520 | + type[AbstractGelModel], |
| 521 | + type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]], |
| 522 | + ], |
| 523 | +] = weakref.WeakKeyDictionary() |
| 524 | + |
| 525 | + |
| 526 | +def create_optional_union( |
| 527 | + lhs: type[_T_Lhs] | None, |
| 528 | + rhs: type[_T_Rhs] | None, |
| 529 | +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None: |
| 530 | + if lhs is None: |
| 531 | + return rhs |
| 532 | + elif rhs is None: |
| 533 | + return lhs |
| 534 | + else: |
| 535 | + return create_union(lhs, rhs) |
| 536 | + |
| 537 | + |
| 538 | +def create_union( |
| 539 | + lhs: type[_T_Lhs], |
| 540 | + rhs: type[_T_Rhs], |
| 541 | +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]: |
| 542 | + """Create a runtime union type which acts like a GelModel.""" |
| 543 | + |
| 544 | + if (lhs_entry := _type_union_cache.get(lhs)) and ( |
| 545 | + rhs_entry := lhs_entry.get(rhs) |
| 546 | + ): |
| 547 | + return rhs_entry # type: ignore[return-value] |
| 548 | + |
| 549 | + # Combine pointer reflections from args |
| 550 | + ptr_reflections: dict[str, _qb.GelPointerReflection] = { |
| 551 | + p_name: p_refl |
| 552 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 553 | + if p_name in rhs.__gel_reflection__.pointers |
| 554 | + } |
| 555 | + |
| 556 | + # Create type reflection for union type |
| 557 | + class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801 |
| 558 | + expr_object_types: set[type[AbstractGelModel]] = getattr( |
| 559 | + lhs.__gel_reflection__, 'expr_object_types', {lhs} |
| 560 | + ) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs}) |
| 561 | + |
| 562 | + type_name = TypeNameUnion( |
| 563 | + args=( |
| 564 | + lhs.__gel_reflection__.type_name, |
| 565 | + rhs.__gel_reflection__.type_name, |
| 566 | + ) |
| 567 | + ) |
| 568 | + |
| 569 | + pointers = ptr_reflections |
| 570 | + |
| 571 | + @classmethod |
| 572 | + def object( |
| 573 | + cls, |
| 574 | + ) -> Any: |
| 575 | + raise NotImplementedError( |
| 576 | + "Type expressions schema objects are inaccessible" |
| 577 | + ) |
| 578 | + |
| 579 | + # Create the resulting union type |
| 580 | + result = type( |
| 581 | + f"({lhs.__name__} | {rhs.__name__})", |
| 582 | + (BaseGelModelUnion,), |
| 583 | + { |
| 584 | + 'lhs': lhs, |
| 585 | + 'rhs': rhs, |
| 586 | + '__gel_reflection__': __gel_reflection__, |
| 587 | + "__gel_proxied_dunders__": frozenset( |
| 588 | + { |
| 589 | + "__backlinks__", |
| 590 | + } |
| 591 | + ), |
| 592 | + }, |
| 593 | + ) |
| 594 | + |
| 595 | + # Generate field descriptors. |
| 596 | + descriptors: dict[str, ModelFieldDescriptor] = { |
| 597 | + p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__) |
| 598 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 599 | + if ( |
| 600 | + hasattr(lhs, p_name) |
| 601 | + and (l_path_alias := getattr(lhs, p_name, None)) is not None |
| 602 | + and isinstance(l_path_alias, _qb.PathAlias) |
| 603 | + ) |
| 604 | + if ( |
| 605 | + hasattr(rhs, p_name) |
| 606 | + and (r_path_alias := getattr(rhs, p_name, None)) is not None |
| 607 | + and isinstance(r_path_alias, _qb.PathAlias) |
| 608 | + ) |
| 609 | + } |
| 610 | + for p_name, descriptor in descriptors.items(): |
| 611 | + setattr(result, p_name, descriptor) |
| 612 | + |
| 613 | + if lhs not in _type_union_cache: |
| 614 | + _type_union_cache[lhs] = weakref.WeakKeyDictionary() |
| 615 | + _type_union_cache[lhs][rhs] = result |
| 616 | + |
| 617 | + return result |
0 commit comments