1919import ast
2020import builtins
2121import collections
22+ import dataclasses
2223import operator
2324import sys
2425import warnings
2526from functools import cached_property
2627from dataclasses import dataclass , field
27- from types import MethodDescriptorType , ModuleType
28+ from types import MethodDescriptorType , ModuleType , MethodType
2829
2930from IPython .utils .decorators import undoc
3031
@@ -353,7 +354,8 @@ class _DummyNamedTuple(NamedTuple):
353354EvaluationPolicyName = Literal ["forbidden" , "minimal" , "limited" , "unsafe" , "dangerous" ]
354355
355356
356- class EvaluationContext (NamedTuple ):
357+ @dataclass
358+ class EvaluationContext :
357359 #: Local namespace
358360 locals : dict
359361 #: Global namespace
@@ -366,7 +368,13 @@ class EvaluationContext(NamedTuple):
366368 #: Auto import method
367369 auto_import : Callable [list [str ], ModuleType ] | None = None
368370 #: Overrides for evaluation policy
369- policy_overrides : dict = {}
371+ policy_overrides : dict = field (default_factory = dict )
372+ #: Transient local namespace used to store mocks
373+ transient_locals : dict = field (default_factory = dict )
374+
375+ def replace (self , / , ** changes ):
376+ """Return a new copy of the context, with specified changes"""
377+ return dataclasses .replace (self , ** changes )
370378
371379
372380class _IdentitySubscript :
@@ -414,14 +422,14 @@ def guarded_eval(code: str, context: EvaluationContext):
414422 locals_ = locals_ .copy ()
415423 locals_ [SUBSCRIPT_MARKER ] = IDENTITY_SUBSCRIPT
416424 code = SUBSCRIPT_MARKER + "[" + code + "]"
417- context = EvaluationContext ( ** { ** context ._asdict (), ** { " locals" : locals_ }} )
425+ context = context .replace ( locals = locals_ )
418426
419427 if context .evaluation == "dangerous" :
420428 return eval (code , context .globals , context .locals )
421429
422- expression = ast .parse (code , mode = "eval " )
430+ node = ast .parse (code , mode = "exec " )
423431
424- return eval_node (expression , context )
432+ return eval_node (node , context )
425433
426434
427435BINARY_OP_DUNDERS : dict [type [ast .operator ], tuple [str ]] = {
@@ -524,6 +532,54 @@ def _validate_policy_overrides(
524532 return all_good
525533
526534
535+ def _handle_assign (node : ast .Assign , context : EvaluationContext ):
536+ value = eval_node (node .value , context )
537+ transient_locals = context .transient_locals
538+ for target in node .targets :
539+ if isinstance (target , (ast .Tuple , ast .List )):
540+ # Handle unpacking assignment
541+ values = list (value )
542+ targets = target .elts
543+ starred = [i for i , t in enumerate (targets ) if isinstance (t , ast .Starred )]
544+
545+ # Unified handling: treat no starred as starred at end
546+ star_or_last_idx = starred [0 ] if starred else len (targets )
547+
548+ # Before starred
549+ for i in range (star_or_last_idx ):
550+ transient_locals [targets [i ].id ] = values [i ]
551+
552+ # Starred if exists
553+ if starred :
554+ end = len (values ) - (len (targets ) - star_or_last_idx - 1 )
555+ transient_locals [targets [star_or_last_idx ].value .id ] = values [
556+ star_or_last_idx :end
557+ ]
558+
559+ # After starred
560+ for i in range (star_or_last_idx + 1 , len (targets )):
561+ transient_locals [targets [i ].id ] = values [
562+ len (values ) - (len (targets ) - i )
563+ ]
564+ else :
565+ transient_locals [target .id ] = value
566+ return None
567+
568+
569+ def _extract_args_and_kwargs (node : ast .Call , context : EvaluationContext ):
570+ args = [eval_node (arg , context ) for arg in node .args ]
571+ kwargs = {
572+ k : v
573+ for kw in node .keywords
574+ for k , v in (
575+ {kw .arg : eval_node (kw .value , context )}
576+ if kw .arg
577+ else eval_node (kw .value , context )
578+ ).items ()
579+ }
580+ return args , kwargs
581+
582+
527583def eval_node (node : Union [ast .AST , None ], context : EvaluationContext ):
528584 """Evaluate AST node in provided context.
529585
@@ -555,8 +611,48 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
555611
556612 if node is None :
557613 return None
614+ if isinstance (node , (ast .Interactive , ast .Module )):
615+ result = None
616+ for child_node in node .body :
617+ result = eval_node (child_node , context )
618+ return result
619+ if isinstance (node , ast .FunctionDef ):
620+ # we ignore body and only extract the return type
621+ # TODO:support decorators?
622+ def dummy_function (* args , ** kwargs ):
623+ pass
624+
625+ return_type = eval_node (node .returns , context = context )
626+ dummy_function .__annotations__ ["return" ] = return_type
627+ dummy_function .__name__ = node .name
628+ dummy_function .__node__ = node
629+ context .transient_locals [node .name ] = dummy_function
630+ return None
631+ if isinstance (node , ast .ClassDef ):
632+ # TODO support decorators?
633+ class_locals = {}
634+ class_context = context .replace (transient_locals = class_locals )
635+ for child_node in node .body :
636+ eval_node (child_node , class_context )
637+ bases = tuple ([eval_node (base , context ) for base in node .bases ])
638+ dummy_class = type (node .name , bases , class_locals )
639+ context .transient_locals [node .name ] = dummy_class
640+ return None
641+ if isinstance (node , ast .Assign ):
642+ return _handle_assign (node , context )
558643 if isinstance (node , ast .Expression ):
559644 return eval_node (node .body , context )
645+ if isinstance (node , ast .Expr ):
646+ return eval_node (node .value , context )
647+ if isinstance (node , ast .Pass ):
648+ return None
649+ if isinstance (node , ast .Import ):
650+ # TODO: populate transient_locals
651+ return None
652+ if isinstance (node , (ast .AugAssign , ast .Delete )):
653+ return None
654+ if isinstance (node , (ast .Global , ast .Nonlocal )):
655+ return None
560656 if isinstance (node , ast .BinOp ):
561657 left = eval_node (node .left , context )
562658 right = eval_node (node .right , context )
@@ -676,9 +772,9 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
676772 return eval_node (node .orelse , context )
677773 if isinstance (node , ast .Call ):
678774 func = eval_node (node .func , context )
679- if policy .can_call (func ) and not node . keywords :
680- args = [ eval_node ( arg , context ) for arg in node . args ]
681- return func (* args )
775+ if policy .can_call (func ):
776+ args , kwargs = _extract_args_and_kwargs ( node , context )
777+ return func (* args , ** kwargs )
682778 if isclass (func ):
683779 # this code path gets entered when calling class e.g. `MyClass()`
684780 # or `my_instance.__class__()` - in both cases `func` is `MyClass`.
@@ -717,21 +813,28 @@ def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext
717813 return NOT_EVALUATED
718814
719815
816+ def _eval_annotation (
817+ annotation : str ,
818+ context : EvaluationContext ,
819+ ):
820+ return (
821+ _eval_node_name (annotation , context )
822+ if isinstance (annotation , str )
823+ else annotation
824+ )
825+
826+
720827def _resolve_annotation (
721- annotation ,
828+ annotation : str ,
722829 sig : Signature ,
723830 func : Callable ,
724831 node : ast .Call ,
725832 context : EvaluationContext ,
726833):
727834 """Resolve annotation created by user with `typing` module and custom objects."""
728- annotation = (
729- _eval_node_name (annotation , context )
730- if isinstance (annotation , str )
731- else annotation
732- )
835+ annotation = _eval_annotation (annotation , context )
733836 origin = get_origin (annotation )
734- if annotation is Self and hasattr (func , "__self__" ):
837+ if annotation is Self and func and hasattr (func , "__self__" ):
735838 return func .__self__
736839 elif origin is Literal :
737840 type_args = get_args (annotation )
@@ -741,12 +844,30 @@ def _resolve_annotation(
741844 return ""
742845 elif annotation is AnyStr :
743846 index = None
744- for i , (key , value ) in enumerate (sig .parameters .items ()):
745- if value .annotation is AnyStr :
746- index = i
747- break
748- if index is not None and index < len (node .args ):
749- return eval_node (node .args [index ], context )
847+ if hasattr (func , "__node__" ):
848+ def_node = func .__node__
849+ for i , arg in enumerate (def_node .args .args ):
850+ if not arg .annotation :
851+ continue
852+ annotation = _eval_annotation (arg .annotation .id , context )
853+ if annotation is AnyStr :
854+ index = i
855+ break
856+ is_bound_method = (
857+ isinstance (func , MethodType ) and getattr (func , "__self__" ) is not None
858+ )
859+ if index and is_bound_method :
860+ index -= 1
861+ else :
862+ for i , (key , value ) in enumerate (sig .parameters .items ()):
863+ if value .annotation is AnyStr :
864+ index = i
865+ break
866+ if index is None :
867+ return None
868+ if index < 0 or index >= len (node .args ):
869+ return None
870+ return eval_node (node .args [index ], context )
750871 elif origin is TypeGuard :
751872 return False
752873 elif origin is Union :
@@ -779,6 +900,8 @@ def _resolve_annotation(
779900
780901def _eval_node_name (node_id : str , context : EvaluationContext ):
781902 policy = get_policy (context )
903+ if node_id in context .transient_locals :
904+ return context .transient_locals [node_id ]
782905 if policy .allow_locals_access and node_id in context .locals :
783906 return context .locals [node_id ]
784907 if policy .allow_globals_access and node_id in context .globals :
@@ -799,9 +922,8 @@ def _eval_node_name(node_id: str, context: EvaluationContext):
799922def _eval_or_create_duck (duck_type , node : ast .Call , context : EvaluationContext ):
800923 policy = get_policy (context )
801924 # if allow-listed builtin is on type annotation, instantiate it
802- if policy .can_call (duck_type ) and not node .keywords :
803- args = [eval_node (arg , context ) for arg in node .args ]
804- return duck_type (* args )
925+ if policy .can_call (duck_type ):
926+ return duck_type ()
805927 # if custom class is in type annotation, mock it
806928 return _create_duck_for_heap_type (duck_type )
807929
@@ -880,6 +1002,8 @@ def _list_methods(cls, source=None):
8801002 * _list_methods (str ),
8811003 tuple ,
8821004 * _list_methods (tuple ),
1005+ bool ,
1006+ * _list_methods (bool ),
8831007 * NUMERICS ,
8841008 * [method for numeric_cls in NUMERICS for method in _list_methods (numeric_cls )],
8851009 collections .deque ,
0 commit comments