Skip to content

Commit aebe495

Browse files
committed
store-temp-locals
1 parent 71be675 commit aebe495

1 file changed

Lines changed: 51 additions & 15 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,12 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
637637
if node is None:
638638
return None
639639
if isinstance(node, (ast.Interactive, ast.Module)):
640+
context_copy = context.replace(locals=context.locals.copy())
641+
module_vars = _extract_variables_from_module(node, context_copy)
642+
context_copy.locals.update(module_vars)
640643
result = None
641644
for child_node in node.body:
642-
result = eval_node(child_node, context)
645+
result = eval_node(child_node, context_copy)
643646
return result
644647
if isinstance(node, ast.FunctionDef):
645648
# we ignore body and only extract the return type
@@ -677,9 +680,7 @@ def dummy_function(*args, **kwargs):
677680
class_context = context.replace(transient_locals=class_locals)
678681
for child_node in node.body:
679682
eval_node(child_node, class_context)
680-
# extract self.attribute assignments
681683
init_attributes = _extract_init_attributes(node, class_context)
682-
# Merge init attributes into class_locals
683684
class_locals.update(init_attributes)
684685
bases = tuple([eval_node(base, context) for base in node.bases])
685686
dummy_class = type(node.name, bases, class_locals)
@@ -878,14 +879,6 @@ def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContex
878879
if not init_method:
879880
return attributes
880881

881-
temp_context = EvaluationContext(
882-
globals=context.globals,
883-
locals=context.locals,
884-
evaluation=context.evaluation,
885-
in_subscript=context.in_subscript,
886-
transient_locals={},
887-
)
888-
889882
for stmt in init_method.body:
890883
# Handle regular assignments: self.attr = value
891884
if isinstance(stmt, ast.Assign):
@@ -895,10 +888,10 @@ def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContex
895888
attr_name = target.attr
896889
try:
897890
# Evaluate the assigned value
898-
value = eval_node(stmt.value, temp_context)
891+
value = eval_node(stmt.value, context)
899892
if value is not None and value is not NOT_EVALUATED:
900893
attributes[attr_name] = value
901-
except Exception:
894+
except Exception as e:
902895
# Skip the attribute
903896
pass
904897

@@ -914,7 +907,7 @@ def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContex
914907
# Try to use the annotation
915908
if stmt.annotation:
916909
try:
917-
annotation = eval_node(stmt.annotation, temp_context)
910+
annotation = eval_node(stmt.annotation, context)
918911
resolved = _resolve_annotation(annotation, context)
919912
if resolved is not None:
920913
attributes[attr_name] = resolved
@@ -925,7 +918,7 @@ def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContex
925918
# Try to infer from value
926919
if stmt.value:
927920
try:
928-
value = eval_node(stmt.value, temp_context)
921+
value = eval_node(stmt.value, context)
929922
if value is not None and value is not NOT_EVALUATED:
930923
attributes[attr_name] = value
931924
except Exception:
@@ -934,6 +927,49 @@ def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContex
934927
return attributes
935928

936929

930+
def _extract_variables_from_module(
931+
module_node: Union[ast.Module, ast.Interactive, None], context: EvaluationContext
932+
):
933+
"""Extract and evaluate variable assignments from a module AST.
934+
935+
Scans the module for top-level variable assignments and evaluates them.
936+
This allows code like:
937+
938+
Args:
939+
module_node: The Module or Interactive AST node
940+
941+
Returns:
942+
Dictionary mapping variable names to their evaluated values
943+
"""
944+
variables = {}
945+
946+
if module_node is None:
947+
return variables
948+
949+
for stmt in module_node.body:
950+
if isinstance(stmt, ast.Assign):
951+
for target in stmt.targets:
952+
if isinstance(target, ast.Name):
953+
var_name = target.id
954+
try:
955+
value = eval_node(stmt.value, context)
956+
if value is not NOT_EVALUATED:
957+
variables[var_name] = value
958+
except Exception:
959+
pass
960+
elif isinstance(stmt, ast.AnnAssign):
961+
if isinstance(stmt.target, ast.Name) and stmt.value:
962+
var_name = stmt.target.id
963+
try:
964+
value = eval_node(stmt.value, context)
965+
if value is not NOT_EVALUATED:
966+
variables[var_name] = value
967+
except Exception:
968+
pass
969+
970+
return variables
971+
972+
937973
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
938974
"""Evaluate return type of a given callable function.
939975

0 commit comments

Comments
 (0)