Skip to content

Commit 69a84ca

Browse files
committed
make-self.attributes-class-locals
1 parent 051f45d commit 69a84ca

2 files changed

Lines changed: 109 additions & 23 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from copy import copy
23
from inspect import isclass, signature, Signature, getmodule
34
from typing import (
@@ -396,6 +397,8 @@ class EvaluationContext:
396397
policy_overrides: dict = field(default_factory=dict)
397398
#: Transient local namespace used to store mocks
398399
transient_locals: dict = field(default_factory=dict)
400+
#: Tranisents of class level
401+
class_transients: dict | None = None
399402

400403
def replace(self, /, **changes):
401404
"""Return a new copy of the context, with specified changes"""
@@ -560,6 +563,7 @@ def _validate_policy_overrides(
560563
def _handle_assign(node: ast.Assign, context: EvaluationContext):
561564
value = eval_node(node.value, context)
562565
transient_locals = context.transient_locals
566+
class_transients = getattr(context, "class_transients", None)
563567
for target in node.targets:
564568
if isinstance(target, (ast.Tuple, ast.List)):
565569
# Handle unpacking assignment
@@ -573,21 +577,55 @@ def _handle_assign(node: ast.Assign, context: EvaluationContext):
573577
# Before starred
574578
for i in range(star_or_last_idx):
575579
transient_locals[targets[i].id] = values[i]
580+
# Check for self.x assignment
581+
if class_transients is not None and hasattr(targets[i], "ctx"):
582+
if (
583+
isinstance(targets[i], ast.Attribute)
584+
and isinstance(targets[i].value, ast.Name)
585+
and targets[i].value.id == "self"
586+
):
587+
class_transients[targets[i].attr] = values[i]
576588

577589
# Starred if exists
578590
if starred:
579591
end = len(values) - (len(targets) - star_or_last_idx - 1)
580592
transient_locals[targets[star_or_last_idx].value.id] = values[
581593
star_or_last_idx:end
582594
]
595+
if (
596+
class_transients is not None
597+
and isinstance(targets[star_or_last_idx], ast.Attribute)
598+
and isinstance(targets[star_or_last_idx].value, ast.Name)
599+
and targets[star_or_last_idx].value.id == "self"
600+
):
601+
class_transients[targets[star_or_last_idx].attr] = values[
602+
star_or_last_idx:end
603+
]
583604

584605
# After starred
585606
for i in range(star_or_last_idx + 1, len(targets)):
586607
transient_locals[targets[i].id] = values[
587608
len(values) - (len(targets) - i)
588609
]
610+
if (
611+
class_transients is not None
612+
and isinstance(targets[i], ast.Attribute)
613+
and isinstance(targets[i].value, ast.Name)
614+
and targets[i].value.id == "self"
615+
):
616+
class_transients[targets[i].attr] = values[
617+
len(values) - (len(targets) - i)
618+
]
589619
else:
590-
transient_locals[target.id] = value
620+
if (
621+
isinstance(target, ast.Attribute)
622+
and isinstance(target.value, ast.Name)
623+
and target.value.id == "self"
624+
):
625+
if class_transients is not None:
626+
class_transients[target.attr] = value
627+
elif hasattr(target, "id"):
628+
transient_locals[target.id] = value
591629
return None
592630

593631

@@ -643,6 +681,10 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
643681
return result
644682
if isinstance(node, ast.FunctionDef):
645683
# we ignore body and only extract the return type
684+
func_locals = context.transient_locals.copy()
685+
func_context = context.replace(transient_locals=func_locals)
686+
for child_node in node.body:
687+
eval_node(child_node, func_context)
646688
is_property = False
647689

648690
for decorator_node in node.decorator_list:
@@ -662,7 +704,7 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
662704
return_type, context
663705
)
664706
else:
665-
return_value = _infer_return_value(node, context)
707+
return_value = _infer_return_value(node, func_context)
666708
context.transient_locals[node.name] = return_value
667709

668710
return None
@@ -673,7 +715,7 @@ def dummy_function(*args, **kwargs):
673715
if return_type is not None:
674716
dummy_function.__annotations__["return"] = return_type
675717
else:
676-
inferred_return = _infer_return_value(node, context)
718+
inferred_return = _infer_return_value(node, func_context)
677719
if inferred_return is not None:
678720
dummy_function.__inferred_return__ = inferred_return
679721

@@ -685,6 +727,7 @@ def dummy_function(*args, **kwargs):
685727
# TODO support class decorators?
686728
class_locals = context.transient_locals.copy()
687729
class_context = context.replace(transient_locals=class_locals)
730+
class_context.class_transients = class_locals
688731
for child_node in node.body:
689732
eval_node(child_node, class_context)
690733
bases = tuple([eval_node(base, context) for base in node.bases])
@@ -694,12 +737,19 @@ def dummy_function(*args, **kwargs):
694737
if isinstance(node, ast.Assign):
695738
return _handle_assign(node, context)
696739
if isinstance(node, ast.AnnAssign):
697-
if not node.simple:
698-
# for now only handle simple annotations
699-
return None
700-
context.transient_locals[node.target.id] = _resolve_annotation(
701-
eval_node(node.annotation, context), context
702-
)
740+
if node.simple:
741+
value = _resolve_annotation(eval_node(node.annotation, context), context)
742+
context.transient_locals[node.target.id] = value
743+
# Handle non-simple annotated assignments only for self.x: type = value
744+
class_transients = getattr(context, "class_transients", None)
745+
if (
746+
class_transients is not None
747+
and isinstance(node.target, ast.Attribute)
748+
and isinstance(node.target.value, ast.Name)
749+
and node.target.value.id == "self"
750+
):
751+
value = _resolve_annotation(eval_node(node.annotation, context), context)
752+
class_transients[node.target.attr] = value
703753
return None
704754
if isinstance(node, ast.Expression):
705755
return eval_node(node.body, context)
@@ -897,22 +947,9 @@ def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
897947

898948

899949
def _collect_return_values(body, context):
900-
"""Recursively collect return values from a list of AST statements.
901-
902-
For every assignment or annotated assignment, store them in context.transient_locals
903-
so that return statements can refer to them.
904-
"""
950+
"""Recursively collect return values from a list of AST statements."""
905951
return_values = []
906952
for stmt in body:
907-
# Handle assignments
908-
if isinstance(stmt, ast.Assign):
909-
_handle_assign(stmt, context)
910-
elif isinstance(stmt, ast.AnnAssign):
911-
if stmt.simple:
912-
context.transient_locals[stmt.target.id] = _resolve_annotation(
913-
eval_node(stmt.annotation, context), context
914-
)
915-
# Handle return statements
916953
if isinstance(stmt, ast.Return):
917954
if stmt.value is None:
918955
continue

tests/test_completer.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2189,6 +2189,55 @@ def _(expected):
21892189
),
21902190
"append",
21912191
],
2192+
[
2193+
"\n".join(
2194+
[
2195+
"class NotYetDefined:",
2196+
" def __init__(self):",
2197+
" self.test = []",
2198+
"instance = NotYetDefined()",
2199+
"instance.",
2200+
]
2201+
),
2202+
"test",
2203+
],
2204+
[
2205+
"\n".join(
2206+
[
2207+
"class NotYetDefined:",
2208+
" def __init__(self):",
2209+
" self.test = []",
2210+
"instance = NotYetDefined()",
2211+
"instance.test.",
2212+
]
2213+
),
2214+
"append",
2215+
],
2216+
[
2217+
"\n".join(
2218+
[
2219+
"class NotYetDefined:",
2220+
" def __init__(self):",
2221+
" self.test:str = []",
2222+
"instance = NotYetDefined()",
2223+
"instance.test.",
2224+
]
2225+
),
2226+
"capitalize",
2227+
],
2228+
[
2229+
"\n".join(
2230+
[
2231+
"l = []",
2232+
"class NotYetDefined:",
2233+
" def __init__(self):",
2234+
" self.test = l",
2235+
"instance = NotYetDefined()",
2236+
"instance.test.",
2237+
]
2238+
),
2239+
"append",
2240+
],
21922241
],
21932242
)
21942243
def test_undefined_variables(use_jedi, evaluation, code, insert_text):

0 commit comments

Comments
 (0)