Skip to content

Commit afb128f

Browse files
committed
infer-complex-return-types
1 parent 6bf4b7e commit afb128f

2 files changed

Lines changed: 77 additions & 10 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,6 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
654654
continue
655655
if decorator is property:
656656
is_property = True
657-
658657
return_type = eval_node(node.returns, context=context)
659658

660659
if is_property:
@@ -674,9 +673,9 @@ def dummy_function(*args, **kwargs):
674673
if return_type is not None:
675674
dummy_function.__annotations__["return"] = return_type
676675
else:
677-
inferred_type = type(_infer_return_value(node, context))
678-
if inferred_type is not None:
679-
dummy_function.__annotations__["return"] = inferred_type
676+
inferred_return = _infer_return_value(node, context)
677+
if inferred_return is not None:
678+
dummy_function.__inferred_return__ = inferred_return
680679

681680
dummy_function.__name__ = node.name
682681
dummy_function.__node__ = node
@@ -847,6 +846,8 @@ def dummy_function(*args, **kwargs):
847846
return overridden_return_type
848847
return _create_duck_for_heap_type(func)
849848
else:
849+
if hasattr(func, "__inferred_return__"):
850+
return func.__inferred_return__
850851
return_type = _eval_return_type(func, node, context)
851852
if return_type is not NOT_EVALUATED:
852853
return return_type
@@ -865,19 +866,61 @@ def dummy_function(*args, **kwargs):
865866

866867

867868
def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
868-
"""Execute the function body to infer its return value."""
869+
"""Infer the return value(s) of a function by evaluating all return statements."""
870+
return_values = _collect_return_values(node.body, context)
871+
872+
if not return_values:
873+
return None
874+
if len(return_values) == 1:
875+
return return_values[0]
876+
877+
types = {type(v) for v in return_values}
878+
if len(types) == 1:
879+
t = next(iter(types))
880+
if t is dict:
881+
keys = set()
882+
for v in return_values:
883+
keys.update(v.keys())
884+
return _Duck(
885+
attributes=dict.fromkeys(dir(dict())), items={k: None for k in keys}
886+
)
887+
elif t in (list, set, tuple):
888+
return t()
889+
else:
890+
return return_values[0]
891+
else:
892+
attributes = set()
893+
for v in return_values:
894+
attributes.update(dir(v))
895+
return _Duck(attributes=dict.fromkeys(attributes))
896+
869897

870-
for stmt in node.body:
898+
def _collect_return_values(body, context):
899+
"""Recursively collect return values from a list of AST statements."""
900+
return_values = []
901+
for stmt in body:
871902
if isinstance(stmt, ast.Return):
872903
if stmt.value is None:
873-
return None
904+
continue
874905
try:
875906
value = eval_node(stmt.value, context)
876-
if value is not NOT_EVALUATED:
877-
return value
907+
if value is not None and value is not NOT_EVALUATED:
908+
return_values.append(value)
878909
except Exception:
879910
pass
880-
return None
911+
elif hasattr(stmt, "body") and isinstance(stmt.body, list):
912+
return_values.extend(_collect_return_values(stmt.body, context))
913+
if isinstance(stmt, ast.Try):
914+
for h in stmt.handlers:
915+
if hasattr(h, "body"):
916+
return_values.extend(_collect_return_values(h.body, context))
917+
if hasattr(stmt, "orelse"):
918+
return_values.extend(_collect_return_values(stmt.orelse, context))
919+
if hasattr(stmt, "finalbody"):
920+
return_values.extend(_collect_return_values(stmt.finalbody, context))
921+
if hasattr(stmt, "orelse") and isinstance(stmt.orelse, list):
922+
return_values.extend(_collect_return_values(stmt.orelse, context))
923+
return return_values
881924

882925

883926
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):

tests/test_completer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,6 +2154,30 @@ def _(expected):
21542154
),
21552155
"append",
21562156
],
2157+
[
2158+
"\n".join(
2159+
[
2160+
"def string_or_int(flag):",
2161+
" if flag:",
2162+
" return 'test'",
2163+
" return 1",
2164+
"string_or_int().",
2165+
]
2166+
),
2167+
"capitalize",
2168+
],
2169+
[
2170+
"\n".join(
2171+
[
2172+
"def string_or_int(flag):",
2173+
" if flag:",
2174+
" return 'test'",
2175+
" return 1",
2176+
"string_or_int().",
2177+
]
2178+
),
2179+
"as_integer_ratio",
2180+
],
21572181
],
21582182
)
21592183
def test_undefined_variables(use_jedi, evaluation, code, insert_text):

0 commit comments

Comments
 (0)