Skip to content

Commit 16e3edf

Browse files
committed
fix-union-types
1 parent d4cd689 commit 16e3edf

1 file changed

Lines changed: 42 additions & 3 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from IPython.utils.decorators import undoc
3232

33-
33+
import types
3434
from typing import Self, LiteralString
3535

3636
if sys.version_info < (3, 12):
@@ -564,6 +564,25 @@ def _validate_policy_overrides(
564564
return all_good
565565

566566

567+
def is_type_annotation(obj) -> bool:
568+
"""
569+
Returns True if obj is a type annotation, False otherwise.
570+
"""
571+
if isinstance(obj, type):
572+
return True
573+
if hasattr(typing, "get_origin"):
574+
if typing.get_origin(obj) is not None:
575+
return True
576+
if hasattr(obj, "__module__") and obj.__module__ == "typing":
577+
return True
578+
if isinstance(obj, types.GenericAlias):
579+
return True
580+
if type(obj).__name__ in ("_GenericAlias", "_SpecialForm", "_UnionGenericAlias"):
581+
return True
582+
583+
return False
584+
585+
567586
def _handle_assign(node: ast.Assign, context: EvaluationContext):
568587
value = eval_node(node.value, context)
569588
transient_locals = context.transient_locals
@@ -904,11 +923,19 @@ def dummy_function(*args, **kwargs):
904923
return _handle_assign(node, context)
905924
if isinstance(node, ast.AnnAssign):
906925
if node.simple:
907-
value = _resolve_annotation(eval_node(node.annotation, context), context)
926+
annotation_result = eval_node(node.annotation, context)
927+
if is_type_annotation(annotation_result):
928+
value = _resolve_annotation(annotation_result, context)
929+
else:
930+
value = annotation_result
908931
context.transient_locals[node.target.id] = value
909932
# Handle non-simple annotated assignments only for self.x: type = value
910933
if _is_instance_attribute_assignment(node.target, context):
911-
value = _resolve_annotation(eval_node(node.annotation, context), context)
934+
annotation_result = eval_node(node.annotation, context)
935+
if is_type_annotation(annotation_result):
936+
value = _resolve_annotation(annotation_result, context)
937+
else:
938+
value = annotation_result
912939
context.class_transients[node.target.attr] = value
913940
return None
914941
if isinstance(node, ast.Expression):
@@ -927,6 +954,18 @@ def dummy_function(*args, **kwargs):
927954
if isinstance(node, ast.BinOp):
928955
left = eval_node(node.left, context)
929956
right = eval_node(node.right, context)
957+
if is_type_annotation(left) and is_type_annotation(right):
958+
left_duck = (
959+
_Duck(dict.fromkeys(dir(left)))
960+
if policy.can_call(left.__dir__)
961+
else _Duck()
962+
)
963+
right_duck = (
964+
_Duck(dict.fromkeys(dir(right)))
965+
if policy.can_call(right.__dir__)
966+
else _Duck()
967+
)
968+
return _merge_values([left_duck, right_duck], policy=get_policy(context))
930969
dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
931970
if dunders:
932971
if policy.can_operate(dunders, left, right):

0 commit comments

Comments
 (0)