diff --git a/mypy/checker.py b/mypy/checker.py index 50290ba25c1bc..c214eb08b8abc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -314,6 +314,25 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal return False +class ReturnTypeFinder(TraverserVisitor): + """Visitor to collect return types from return statements in a function body. + + This is used to infer return types for functions without explicit return type annotations. + """ + + def __init__(self, typemap: dict[Expression, Type]) -> None: + self.typemap = typemap + self.return_types: list[Type] = [] + + def visit_return_stmt(self, o: ReturnStmt) -> None: + if o.expr is not None and o.expr in self.typemap: + self.return_types.append(self.typemap[o.expr]) + + def visit_func_def(self, o: FuncDef) -> None: + # Skip nested functions + pass + + class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): """Mypy type checker. @@ -1600,6 +1619,40 @@ def check_func_def( ): self.note(message_registry.EMPTY_BODY_ABSTRACT, defn) + # Infer return type from return statements if function has no explicit return type annotation + if isinstance(item, FuncDef) and isinstance(typ, CallableType): + + def is_unannotated_any(t: Type) -> bool: + if not isinstance(t, ProperType): + return False + return isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated + + ret_type_proper = get_proper_type(typ.ret_type) + # Only infer for functions without explicit return type annotations + # Skip generators and coroutines as they have special return type handling + if ( + is_unannotated_any(ret_type_proper) + and not defn.is_generator + and not defn.is_coroutine + and not self.dynamic_funcs[-1] + and item.body is not None + ): + # Collect return types from return statements + # Use the master type map (first in stack) where final types are stored + # At this point in type checking, return statement types should be in the master map + finder = ReturnTypeFinder(self._type_maps[0]) + item.body.accept(finder) + return_types_list = finder.return_types + + if return_types_list: + # Create union of all return types + inferred_ret_type = make_simplified_union(return_types_list) + # Update the function's return type + typ = typ.copy_modified(ret_type=inferred_ret_type) + item.type = typ + # Update the return_types stack as well + self.return_types[-1] = inferred_ret_type + self.return_types.pop() self.binder = old_binder diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 5b1a59e359c19..d073c68c7e51d 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -88,6 +88,29 @@ def f() -> Iterator[int]: return "foo" # E: No return value expected [out] +[case testInferReturnTypeFromReturnStatements] +# Test that mypy infers return type from return statements when function has no explicit return type annotation +def f(x: int): + if x > 0: + return "positive" + else: + return 0 + +reveal_type(f(1)) # N: Revealed type is "builtins.str | builtins.int" + +def g(x: bool): + return x + +reveal_type(g(True)) # N: Revealed type is "builtins.bool" + +def h(x: int): + if x > 0: + return "positive" + return None + +reveal_type(h(1)) # N: Revealed type is "builtins.str | None" + +[out] -- If statement -- ------------