Skip to content

Commit c3cac5b

Browse files
committed
infer-from-init
1 parent 79f313a commit c3cac5b

1 file changed

Lines changed: 114 additions & 0 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,10 @@ def dummy_function(*args, **kwargs):
677677
class_context = context.replace(transient_locals=class_locals)
678678
for child_node in node.body:
679679
eval_node(child_node, class_context)
680+
# extract self.attribute assignments
681+
init_attributes = _extract_init_attributes(node, class_context)
682+
# Merge init attributes into class_locals
683+
class_locals.update(init_attributes)
680684
bases = tuple([eval_node(base, context) for base in node.bases])
681685
dummy_class = type(node.name, bases, class_locals)
682686
context.transient_locals[node.name] = dummy_class
@@ -853,6 +857,116 @@ def dummy_function(*args, **kwargs):
853857
return None
854858

855859

860+
def _extract_init_attributes(class_node: ast.ClassDef, context: EvaluationContext):
861+
"""Extract attribute assignments from __init__ method.
862+
863+
Looks for patterns like:
864+
self.attr = value
865+
self.attr: Type = value
866+
867+
And infers their types by evaluating the assigned values.
868+
869+
Returns dictionary mapping attribute names to their inferred Duck types
870+
"""
871+
attributes = {}
872+
init_method = None
873+
for node in class_node.body:
874+
if isinstance(node, ast.FunctionDef) and node.name == "__init__":
875+
init_method = node
876+
break
877+
878+
if not init_method:
879+
return attributes
880+
881+
temp_context = EvaluationContext(
882+
globals=context.globals,
883+
locals=context.locals,
884+
evaluation=context.evaluation,
885+
in_subscript=context.in_subscript,
886+
transient_locals={},
887+
)
888+
889+
for stmt in init_method.body:
890+
# Handle regular assignments: self.attr = value
891+
if isinstance(stmt, ast.Assign):
892+
for target in stmt.targets:
893+
if isinstance(target, ast.Attribute):
894+
if isinstance(target.value, ast.Name) and target.value.id == "self":
895+
attr_name = target.attr
896+
try:
897+
# Evaluate the assigned value
898+
value = eval_node(stmt.value, temp_context)
899+
if value is not None and value is not NOT_EVALUATED:
900+
inferred_type = _create_duck_from_value(value)
901+
if inferred_type is not None:
902+
attributes[attr_name] = inferred_type
903+
except Exception:
904+
# Skip the attribute
905+
pass
906+
907+
# Handle annotated assignments: self.attr: Type = value
908+
elif isinstance(stmt, ast.AnnAssign):
909+
if isinstance(stmt.target, ast.Attribute):
910+
if (
911+
isinstance(stmt.target.value, ast.Name)
912+
and stmt.target.value.id == "self"
913+
):
914+
attr_name = stmt.target.attr
915+
916+
# Try to use the annotation
917+
if stmt.annotation:
918+
try:
919+
annotation = eval_node(stmt.annotation, temp_context)
920+
resolved = _resolve_annotation(annotation, context)
921+
if resolved is not None:
922+
attributes[attr_name] = resolved
923+
continue
924+
except Exception:
925+
pass
926+
927+
# Try to infer from value
928+
if stmt.value:
929+
try:
930+
value = eval_node(stmt.value, temp_context)
931+
if value is not None and value is not NOT_EVALUATED:
932+
inferred_type = _create_duck_from_value(value)
933+
if inferred_type is not None:
934+
attributes[attr_name] = inferred_type
935+
except Exception:
936+
pass
937+
938+
return attributes
939+
940+
941+
def _create_duck_from_value(value):
942+
"""Create a Duck object from an actual runtime value."""
943+
if value is None or value is NOT_EVALUATED:
944+
return None
945+
value_type = type(value)
946+
if isinstance(value, dict):
947+
return _Duck(
948+
attributes=dict.fromkeys(dir(dict())), items=value if value else {}
949+
)
950+
elif isinstance(value, list):
951+
element_duck = None
952+
if value:
953+
element_duck = _create_duck_from_value(value[0])
954+
return _Duck(
955+
attributes=dict.fromkeys(dir(list())),
956+
items=_GetItemDuck(lambda: element_duck),
957+
)
958+
elif isinstance(value, set):
959+
return _Duck(attributes=dict.fromkeys(dir(set())))
960+
elif isinstance(value, tuple):
961+
return value
962+
elif isinstance(value, (str, int, float, bool, bytes)):
963+
return _Duck(attributes=dict.fromkeys(dir(value_type())))
964+
else:
965+
try:
966+
return _create_duck_for_heap_type(value_type)
967+
except Exception:
968+
return _Duck(attributes=dict.fromkeys(dir(value)))
969+
856970
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
857971
"""Evaluate return type of a given callable function.
858972

0 commit comments

Comments
 (0)