Skip to content

Commit 076f5e7

Browse files
committed
Add type-guided partial evaluation for completion of unitialised variables
1 parent 78d65f0 commit 076f5e7

3 files changed

Lines changed: 288 additions & 41 deletions

File tree

IPython/core/completer.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,10 @@ def _extract_code(self, line: str):
12241224
return line
12251225

12261226
def _attr_matches(
1227-
self, text: str, include_prefix: bool = True
1227+
self,
1228+
text: str,
1229+
include_prefix: bool = True,
1230+
context: Optional[CompletionContext] = None,
12281231
) -> tuple[Sequence[str], str]:
12291232
m2 = self._ATTR_MATCH_RE.match(text)
12301233
if not m2:
@@ -1237,7 +1240,19 @@ def _attr_matches(
12371240

12381241
obj = self._evaluate_expr(expr)
12391242
if obj is not_found:
1240-
return [], ""
1243+
if context:
1244+
# try to evaluate on full buffer
1245+
previous_lines = "\n".join(
1246+
context.full_text.split("\n")[: context.cursor_line]
1247+
)
1248+
if previous_lines:
1249+
all_code_lines_before_cursor = (
1250+
self._extract_code(previous_lines) + "\n" + expr
1251+
)
1252+
obj = self._evaluate_expr(all_code_lines_before_cursor)
1253+
1254+
if obj is not_found:
1255+
return [], ""
12411256

12421257
if self.limit_to__all__ and hasattr(obj, '__all__'):
12431258
words = get__all__entries(obj)
@@ -2678,7 +2693,9 @@ def python_matcher(self, context: CompletionContext) -> SimpleMatcherResult:
26782693
completion_type = self._determine_completion_context(text)
26792694
if completion_type == self._CompletionContextType.ATTRIBUTE:
26802695
try:
2681-
matches, fragment = self._attr_matches(text, include_prefix=False)
2696+
matches, fragment = self._attr_matches(
2697+
text, include_prefix=False, context=context
2698+
)
26822699
if text.endswith(".") and self.omit__names:
26832700
if self.omit__names == 1:
26842701
# true if txt is _not_ a __ name, false otherwise:

IPython/core/guarded_eval.py

Lines changed: 149 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
import ast
2020
import builtins
2121
import collections
22+
import dataclasses
2223
import operator
2324
import sys
2425
import warnings
2526
from functools import cached_property
2627
from dataclasses import dataclass, field
27-
from types import MethodDescriptorType, ModuleType
28+
from types import MethodDescriptorType, ModuleType, MethodType
2829

2930
from IPython.utils.decorators import undoc
3031

@@ -353,7 +354,8 @@ class _DummyNamedTuple(NamedTuple):
353354
EvaluationPolicyName = Literal["forbidden", "minimal", "limited", "unsafe", "dangerous"]
354355

355356

356-
class EvaluationContext(NamedTuple):
357+
@dataclass
358+
class EvaluationContext:
357359
#: Local namespace
358360
locals: dict
359361
#: Global namespace
@@ -366,7 +368,13 @@ class EvaluationContext(NamedTuple):
366368
#: Auto import method
367369
auto_import: Callable[list[str], ModuleType] | None = None
368370
#: Overrides for evaluation policy
369-
policy_overrides: dict = {}
371+
policy_overrides: dict = field(default_factory=dict)
372+
#: Transient local namespace used to store mocks
373+
transient_locals: dict = field(default_factory=dict)
374+
375+
def replace(self, /, **changes):
376+
"""Return a new copy of the context, with specified changes"""
377+
return dataclasses.replace(self, **changes)
370378

371379

372380
class _IdentitySubscript:
@@ -414,14 +422,14 @@ def guarded_eval(code: str, context: EvaluationContext):
414422
locals_ = locals_.copy()
415423
locals_[SUBSCRIPT_MARKER] = IDENTITY_SUBSCRIPT
416424
code = SUBSCRIPT_MARKER + "[" + code + "]"
417-
context = EvaluationContext(**{**context._asdict(), **{"locals": locals_}})
425+
context = context.replace(locals=locals_)
418426

419427
if context.evaluation == "dangerous":
420428
return eval(code, context.globals, context.locals)
421429

422-
expression = ast.parse(code, mode="eval")
430+
node = ast.parse(code, mode="exec")
423431

424-
return eval_node(expression, context)
432+
return eval_node(node, context)
425433

426434

427435
BINARY_OP_DUNDERS: dict[type[ast.operator], tuple[str]] = {
@@ -524,6 +532,54 @@ def _validate_policy_overrides(
524532
return all_good
525533

526534

535+
def _handle_assign(node: ast.Assign, context: EvaluationContext):
536+
value = eval_node(node.value, context)
537+
transient_locals = context.transient_locals
538+
for target in node.targets:
539+
if isinstance(target, (ast.Tuple, ast.List)):
540+
# Handle unpacking assignment
541+
values = list(value)
542+
targets = target.elts
543+
starred = [i for i, t in enumerate(targets) if isinstance(t, ast.Starred)]
544+
545+
# Unified handling: treat no starred as starred at end
546+
star_or_last_idx = starred[0] if starred else len(targets)
547+
548+
# Before starred
549+
for i in range(star_or_last_idx):
550+
transient_locals[targets[i].id] = values[i]
551+
552+
# Starred if exists
553+
if starred:
554+
end = len(values) - (len(targets) - star_or_last_idx - 1)
555+
transient_locals[targets[star_or_last_idx].value.id] = values[
556+
star_or_last_idx:end
557+
]
558+
559+
# After starred
560+
for i in range(star_or_last_idx + 1, len(targets)):
561+
transient_locals[targets[i].id] = values[
562+
len(values) - (len(targets) - i)
563+
]
564+
else:
565+
transient_locals[target.id] = value
566+
return None
567+
568+
569+
def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
570+
args = [eval_node(arg, context) for arg in node.args]
571+
kwargs = {
572+
k: v
573+
for kw in node.keywords
574+
for k, v in (
575+
{kw.arg: eval_node(kw.value, context)}
576+
if kw.arg
577+
else eval_node(kw.value, context)
578+
).items()
579+
}
580+
return args, kwargs
581+
582+
527583
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
528584
"""Evaluate AST node in provided context.
529585
@@ -555,8 +611,48 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
555611

556612
if node is None:
557613
return None
614+
if isinstance(node, (ast.Interactive, ast.Module)):
615+
result = None
616+
for child_node in node.body:
617+
result = eval_node(child_node, context)
618+
return result
619+
if isinstance(node, ast.FunctionDef):
620+
# we ignore body and only extract the return type
621+
# TODO:support decorators?
622+
def dummy_function(*args, **kwargs):
623+
pass
624+
625+
return_type = eval_node(node.returns, context=context)
626+
dummy_function.__annotations__["return"] = return_type
627+
dummy_function.__name__ = node.name
628+
dummy_function.__node__ = node
629+
context.transient_locals[node.name] = dummy_function
630+
return None
631+
if isinstance(node, ast.ClassDef):
632+
# TODO support decorators?
633+
class_locals = {}
634+
class_context = context.replace(transient_locals=class_locals)
635+
for child_node in node.body:
636+
eval_node(child_node, class_context)
637+
bases = tuple([eval_node(base, context) for base in node.bases])
638+
dummy_class = type(node.name, bases, class_locals)
639+
context.transient_locals[node.name] = dummy_class
640+
return None
641+
if isinstance(node, ast.Assign):
642+
return _handle_assign(node, context)
558643
if isinstance(node, ast.Expression):
559644
return eval_node(node.body, context)
645+
if isinstance(node, ast.Expr):
646+
return eval_node(node.value, context)
647+
if isinstance(node, ast.Pass):
648+
return None
649+
if isinstance(node, ast.Import):
650+
# TODO: populate transient_locals
651+
return None
652+
if isinstance(node, (ast.AugAssign, ast.Delete)):
653+
return None
654+
if isinstance(node, (ast.Global, ast.Nonlocal)):
655+
return None
560656
if isinstance(node, ast.BinOp):
561657
left = eval_node(node.left, context)
562658
right = eval_node(node.right, context)
@@ -676,9 +772,9 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
676772
return eval_node(node.orelse, context)
677773
if isinstance(node, ast.Call):
678774
func = eval_node(node.func, context)
679-
if policy.can_call(func) and not node.keywords:
680-
args = [eval_node(arg, context) for arg in node.args]
681-
return func(*args)
775+
if policy.can_call(func):
776+
args, kwargs = _extract_args_and_kwargs(node, context)
777+
return func(*args, **kwargs)
682778
if isclass(func):
683779
# this code path gets entered when calling class e.g. `MyClass()`
684780
# or `my_instance.__class__()` - in both cases `func` is `MyClass`.
@@ -717,21 +813,28 @@ def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext
717813
return NOT_EVALUATED
718814

719815

816+
def _eval_annotation(
817+
annotation: str,
818+
context: EvaluationContext,
819+
):
820+
return (
821+
_eval_node_name(annotation, context)
822+
if isinstance(annotation, str)
823+
else annotation
824+
)
825+
826+
720827
def _resolve_annotation(
721-
annotation,
828+
annotation: str,
722829
sig: Signature,
723830
func: Callable,
724831
node: ast.Call,
725832
context: EvaluationContext,
726833
):
727834
"""Resolve annotation created by user with `typing` module and custom objects."""
728-
annotation = (
729-
_eval_node_name(annotation, context)
730-
if isinstance(annotation, str)
731-
else annotation
732-
)
835+
annotation = _eval_annotation(annotation, context)
733836
origin = get_origin(annotation)
734-
if annotation is Self and hasattr(func, "__self__"):
837+
if annotation is Self and func and hasattr(func, "__self__"):
735838
return func.__self__
736839
elif origin is Literal:
737840
type_args = get_args(annotation)
@@ -741,12 +844,30 @@ def _resolve_annotation(
741844
return ""
742845
elif annotation is AnyStr:
743846
index = None
744-
for i, (key, value) in enumerate(sig.parameters.items()):
745-
if value.annotation is AnyStr:
746-
index = i
747-
break
748-
if index is not None and index < len(node.args):
749-
return eval_node(node.args[index], context)
847+
if hasattr(func, "__node__"):
848+
def_node = func.__node__
849+
for i, arg in enumerate(def_node.args.args):
850+
if not arg.annotation:
851+
continue
852+
annotation = _eval_annotation(arg.annotation.id, context)
853+
if annotation is AnyStr:
854+
index = i
855+
break
856+
is_bound_method = (
857+
isinstance(func, MethodType) and getattr(func, "__self__") is not None
858+
)
859+
if index and is_bound_method:
860+
index -= 1
861+
else:
862+
for i, (key, value) in enumerate(sig.parameters.items()):
863+
if value.annotation is AnyStr:
864+
index = i
865+
break
866+
if index is None:
867+
return None
868+
if index < 0 or index >= len(node.args):
869+
return None
870+
return eval_node(node.args[index], context)
750871
elif origin is TypeGuard:
751872
return False
752873
elif origin is Union:
@@ -779,6 +900,8 @@ def _resolve_annotation(
779900

780901
def _eval_node_name(node_id: str, context: EvaluationContext):
781902
policy = get_policy(context)
903+
if node_id in context.transient_locals:
904+
return context.transient_locals[node_id]
782905
if policy.allow_locals_access and node_id in context.locals:
783906
return context.locals[node_id]
784907
if policy.allow_globals_access and node_id in context.globals:
@@ -799,9 +922,8 @@ def _eval_node_name(node_id: str, context: EvaluationContext):
799922
def _eval_or_create_duck(duck_type, node: ast.Call, context: EvaluationContext):
800923
policy = get_policy(context)
801924
# if allow-listed builtin is on type annotation, instantiate it
802-
if policy.can_call(duck_type) and not node.keywords:
803-
args = [eval_node(arg, context) for arg in node.args]
804-
return duck_type(*args)
925+
if policy.can_call(duck_type):
926+
return duck_type()
805927
# if custom class is in type annotation, mock it
806928
return _create_duck_for_heap_type(duck_type)
807929

@@ -880,6 +1002,8 @@ def _list_methods(cls, source=None):
8801002
*_list_methods(str),
8811003
tuple,
8821004
*_list_methods(tuple),
1005+
bool,
1006+
*_list_methods(bool),
8831007
*NUMERICS,
8841008
*[method for numeric_cls in NUMERICS for method in _list_methods(numeric_cls)],
8851009
collections.deque,

0 commit comments

Comments
 (0)