Skip to content

Commit 5425002

Browse files
committed
infer-for-functions-too
1 parent 3f98bc8 commit 5425002

1 file changed

Lines changed: 30 additions & 11 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -663,15 +663,21 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
663663
return_type, context
664664
)
665665
else:
666-
inferred_type = _infer_property_return_type(node, context)
667-
context.transient_locals[node.name] = inferred_type
666+
inferred_duck_object = _get_duck_from_return_value(node, context)
667+
context.transient_locals[node.name] = inferred_duck_object
668668

669669
return None
670670

671671
def dummy_function(*args, **kwargs):
672672
pass
673673

674-
dummy_function.__annotations__["return"] = return_type
674+
if return_type is not None:
675+
dummy_function.__annotations__["return"] = return_type
676+
else:
677+
inferred_type = type(_infer_return_value(node, context))
678+
if inferred_type is not None:
679+
dummy_function.__annotations__["return"] = inferred_type
680+
675681
dummy_function.__name__ = node.name
676682
dummy_function.__node__ = node
677683
context.transient_locals[node.name] = dummy_function
@@ -858,28 +864,41 @@ def dummy_function(*args, **kwargs):
858864
return None
859865

860866

861-
def _infer_property_return_type(node: ast.FunctionDef, context: EvaluationContext):
862-
"""Infer the return type of a property by executing its body."""
867+
def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
868+
"""Execute the function body to infer its return value."""
863869
temp_context = EvaluationContext(
864870
globals=context.globals,
865-
locals=context.locals,
871+
locals=context.locals.copy(),
866872
evaluation=context.evaluation,
867873
in_subscript=context.in_subscript,
868874
transient_locals={},
869875
)
870876

871877
for stmt in node.body:
872-
if isinstance(stmt, ast.Return) and stmt.value is not None:
878+
if isinstance(stmt, ast.Return):
879+
if stmt.value is None:
880+
return None
873881
try:
874-
return_value = eval_node(stmt.value, temp_context)
875-
if return_value is not None and return_value is not NOT_EVALUATED:
876-
temp = _create_duck_from_value(return_value)
877-
return temp
882+
value = eval_node(stmt.value, temp_context)
883+
if value is not NOT_EVALUATED:
884+
return value
878885
except Exception:
879886
pass
880887
return None
881888

882889

890+
def _get_duck_from_return_value(node: ast.FunctionDef, context: EvaluationContext):
891+
"""Infer the 'duck type' from the first valid return value."""
892+
try:
893+
return_value = _infer_return_value(node, context)
894+
if return_value is not None and return_value is not NOT_EVALUATED:
895+
duck = _create_duck_from_value(return_value)
896+
return duck
897+
except Exception:
898+
pass
899+
return None
900+
901+
883902
def _create_duck_from_value(value):
884903
"""Create a Duck object from an actual runtime value."""
885904
if value is None or value is NOT_EVALUATED:

0 commit comments

Comments
 (0)