@@ -654,7 +654,6 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
654654 continue
655655 if decorator is property :
656656 is_property = True
657-
658657 return_type = eval_node (node .returns , context = context )
659658
660659 if is_property :
@@ -674,9 +673,9 @@ def dummy_function(*args, **kwargs):
674673 if return_type is not None :
675674 dummy_function .__annotations__ ["return" ] = return_type
676675 else :
677- inferred_type = type ( _infer_return_value (node , context ) )
678- if inferred_type is not None :
679- dummy_function .__annotations__ [ "return" ] = inferred_type
676+ inferred_return = _infer_return_value (node , context )
677+ if inferred_return is not None :
678+ dummy_function .__inferred_return__ = inferred_return
680679
681680 dummy_function .__name__ = node .name
682681 dummy_function .__node__ = node
@@ -847,6 +846,8 @@ def dummy_function(*args, **kwargs):
847846 return overridden_return_type
848847 return _create_duck_for_heap_type (func )
849848 else :
849+ if hasattr (func , "__inferred_return__" ):
850+ return func .__inferred_return__
850851 return_type = _eval_return_type (func , node , context )
851852 if return_type is not NOT_EVALUATED :
852853 return return_type
@@ -865,19 +866,61 @@ def dummy_function(*args, **kwargs):
865866
866867
867868def _infer_return_value (node : ast .FunctionDef , context : EvaluationContext ):
868- """Execute the function body to infer its return value."""
869+ """Infer the return value(s) of a function by evaluating all return statements."""
870+ return_values = _collect_return_values (node .body , context )
871+
872+ if not return_values :
873+ return None
874+ if len (return_values ) == 1 :
875+ return return_values [0 ]
876+
877+ types = {type (v ) for v in return_values }
878+ if len (types ) == 1 :
879+ t = next (iter (types ))
880+ if t is dict :
881+ keys = set ()
882+ for v in return_values :
883+ keys .update (v .keys ())
884+ return _Duck (
885+ attributes = dict .fromkeys (dir (dict ())), items = {k : None for k in keys }
886+ )
887+ elif t in (list , set , tuple ):
888+ return t ()
889+ else :
890+ return return_values [0 ]
891+ else :
892+ attributes = set ()
893+ for v in return_values :
894+ attributes .update (dir (v ))
895+ return _Duck (attributes = dict .fromkeys (attributes ))
896+
869897
870- for stmt in node .body :
898+ def _collect_return_values (body , context ):
899+ """Recursively collect return values from a list of AST statements."""
900+ return_values = []
901+ for stmt in body :
871902 if isinstance (stmt , ast .Return ):
872903 if stmt .value is None :
873- return None
904+ continue
874905 try :
875906 value = eval_node (stmt .value , context )
876- if value is not NOT_EVALUATED :
877- return value
907+ if value is not None and value is not NOT_EVALUATED :
908+ return_values . append ( value )
878909 except Exception :
879910 pass
880- return None
911+ elif hasattr (stmt , "body" ) and isinstance (stmt .body , list ):
912+ return_values .extend (_collect_return_values (stmt .body , context ))
913+ if isinstance (stmt , ast .Try ):
914+ for h in stmt .handlers :
915+ if hasattr (h , "body" ):
916+ return_values .extend (_collect_return_values (h .body , context ))
917+ if hasattr (stmt , "orelse" ):
918+ return_values .extend (_collect_return_values (stmt .orelse , context ))
919+ if hasattr (stmt , "finalbody" ):
920+ return_values .extend (_collect_return_values (stmt .finalbody , context ))
921+ if hasattr (stmt , "orelse" ) and isinstance (stmt .orelse , list ):
922+ return_values .extend (_collect_return_values (stmt .orelse , context ))
923+ return return_values
881924
882925
883926def _eval_return_type (func : Callable , node : ast .Call , context : EvaluationContext ):
0 commit comments