3030
3131from IPython .utils .decorators import undoc
3232
33-
33+ import types
3434from typing import Self , LiteralString
3535
3636if sys .version_info < (3 , 12 ):
@@ -564,6 +564,25 @@ def _validate_policy_overrides(
564564 return all_good
565565
566566
567+ def is_type_annotation (obj ) -> bool :
568+ """
569+ Returns True if obj is a type annotation, False otherwise.
570+ """
571+ if isinstance (obj , type ):
572+ return True
573+ if hasattr (typing , "get_origin" ):
574+ if typing .get_origin (obj ) is not None :
575+ return True
576+ if hasattr (obj , "__module__" ) and obj .__module__ == "typing" :
577+ return True
578+ if isinstance (obj , types .GenericAlias ):
579+ return True
580+ if type (obj ).__name__ in ("_GenericAlias" , "_SpecialForm" , "_UnionGenericAlias" ):
581+ return True
582+
583+ return False
584+
585+
567586def _handle_assign (node : ast .Assign , context : EvaluationContext ):
568587 value = eval_node (node .value , context )
569588 transient_locals = context .transient_locals
@@ -904,11 +923,19 @@ def dummy_function(*args, **kwargs):
904923 return _handle_assign (node , context )
905924 if isinstance (node , ast .AnnAssign ):
906925 if node .simple :
907- value = _resolve_annotation (eval_node (node .annotation , context ), context )
926+ annotation_result = eval_node (node .annotation , context )
927+ if is_type_annotation (annotation_result ):
928+ value = _resolve_annotation (annotation_result , context )
929+ else :
930+ value = annotation_result
908931 context .transient_locals [node .target .id ] = value
909932 # Handle non-simple annotated assignments only for self.x: type = value
910933 if _is_instance_attribute_assignment (node .target , context ):
911- value = _resolve_annotation (eval_node (node .annotation , context ), context )
934+ annotation_result = eval_node (node .annotation , context )
935+ if is_type_annotation (annotation_result ):
936+ value = _resolve_annotation (annotation_result , context )
937+ else :
938+ value = annotation_result
912939 context .class_transients [node .target .attr ] = value
913940 return None
914941 if isinstance (node , ast .Expression ):
@@ -927,6 +954,18 @@ def dummy_function(*args, **kwargs):
927954 if isinstance (node , ast .BinOp ):
928955 left = eval_node (node .left , context )
929956 right = eval_node (node .right , context )
957+ if is_type_annotation (left ) and is_type_annotation (right ):
958+ left_duck = (
959+ _Duck (dict .fromkeys (dir (left )))
960+ if policy .can_call (left .__dir__ )
961+ else _Duck ()
962+ )
963+ right_duck = (
964+ _Duck (dict .fromkeys (dir (right )))
965+ if policy .can_call (right .__dir__ )
966+ else _Duck ()
967+ )
968+ return _merge_values ([left_duck , right_duck ], policy = get_policy (context ))
930969 dunders = _find_dunder (node .op , BINARY_OP_DUNDERS )
931970 if dunders :
932971 if policy .can_operate (dunders , left , right ):
0 commit comments