Skip to content

Commit 7cc8cdf

Browse files
committed
include-async-functions
1 parent 089818d commit 7cc8cdf

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,8 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
721721
for child_node in node.body:
722722
result = eval_node(child_node, context)
723723
return result
724-
if isinstance(node, ast.FunctionDef):
724+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
725+
is_async = isinstance(node, ast.AsyncFunctionDef)
725726
func_locals = context.transient_locals.copy()
726727
func_context = context.replace(transient_locals=func_locals)
727728
is_property = False
@@ -775,6 +776,7 @@ def dummy_function(*args, **kwargs):
775776

776777
dummy_function.__name__ = node.name
777778
dummy_function.__node__ = node
779+
dummy_function.__is_async__ = is_async
778780
context.transient_locals[node.name] = dummy_function
779781
return None
780782
if isinstance(node, ast.ClassDef):
@@ -792,6 +794,11 @@ def dummy_function(*args, **kwargs):
792794
dummy_class = type(node.name, bases, class_locals)
793795
context.transient_locals[node.name] = dummy_class
794796
return None
797+
if isinstance(node, ast.Await):
798+
value = eval_node(node.value, context)
799+
if hasattr(value, "__awaited_type__"):
800+
return value.__awaited_type__
801+
return value
795802
if isinstance(node, ast.Assign):
796803
return _handle_assign(node, context)
797804
if isinstance(node, ast.AnnAssign):
@@ -954,9 +961,17 @@ def dummy_function(*args, **kwargs):
954961
return overridden_return_type
955962
return _create_duck_for_heap_type(func)
956963
else:
957-
if hasattr(func, "__inferred_return__"):
958-
return func.__inferred_return__
964+
inferred_return = getattr(func, "__inferred_return__", None)
959965
return_type = _eval_return_type(func, node, context)
966+
if getattr(func, "__is_async__", False):
967+
awaited_type = (
968+
inferred_return if inferred_return is not None else return_type
969+
)
970+
coroutine_duck = ImpersonatingDuck()
971+
coroutine_duck.__awaited_type__ = awaited_type
972+
return coroutine_duck
973+
if inferred_return:
974+
return inferred_return
960975
if return_type is not NOT_EVALUATED:
961976
return return_type
962977
raise GuardRejection(

0 commit comments

Comments
 (0)