@@ -867,7 +867,8 @@ def dummy_function(*args, **kwargs):
867867
868868def _infer_return_value (node : ast .FunctionDef , context : EvaluationContext ):
869869 """Infer the return value(s) of a function by evaluating all return statements."""
870- return_values = _collect_return_values (node .body , context )
870+ func_context = context .replace (transient_locals = context .transient_locals .copy ())
871+ return_values = _collect_return_values (node .body , func_context )
871872
872873 if not return_values :
873874 return None
@@ -896,9 +897,22 @@ def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
896897
897898
898899def _collect_return_values (body , context ):
899- """Recursively collect return values from a list of AST statements."""
900+ """Recursively collect return values from a list of AST statements.
901+
902+ For every assignment or annotated assignment, store them in context.transient_locals
903+ so that return statements can refer to them.
904+ """
900905 return_values = []
901906 for stmt in body :
907+ # Handle assignments
908+ if isinstance (stmt , ast .Assign ):
909+ _handle_assign (stmt , context )
910+ elif isinstance (stmt , ast .AnnAssign ):
911+ if stmt .simple :
912+ context .transient_locals [stmt .target .id ] = _resolve_annotation (
913+ eval_node (stmt .annotation , context ), context
914+ )
915+ # Handle return statements
902916 if isinstance (stmt , ast .Return ):
903917 if stmt .value is None :
904918 continue
0 commit comments