|
148 | 148 | from mypy.operators import flip_ops, int_op_to_method, neg_ops |
149 | 149 | from mypy.options import PRECISE_TUPLE_TYPES, Options |
150 | 150 | from mypy.patterns import AsPattern, StarredPattern |
| 151 | +from mypy.traverser import TraverserVisitor |
151 | 152 | from mypy.plugin import Plugin |
152 | 153 | from mypy.plugins import dataclasses as dataclasses_plugin |
153 | 154 | from mypy.scope import Scope |
@@ -314,6 +315,25 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal |
314 | 315 | return False |
315 | 316 |
|
316 | 317 |
|
| 318 | +class ReturnTypeFinder(TraverserVisitor): |
| 319 | + """Visitor to collect return types from return statements in a function body. |
| 320 | + |
| 321 | + This is used to infer return types for functions without explicit return type annotations. |
| 322 | + """ |
| 323 | + |
| 324 | + def __init__(self, typemap: dict[Expression, Type]) -> None: |
| 325 | + self.typemap = typemap |
| 326 | + self.return_types: list[Type] = [] |
| 327 | + |
| 328 | + def visit_return_stmt(self, o: ReturnStmt) -> None: |
| 329 | + if o.expr is not None and o.expr in self.typemap: |
| 330 | + self.return_types.append(self.typemap[o.expr]) |
| 331 | + |
| 332 | + def visit_func_def(self, o: FuncDef) -> None: |
| 333 | + # Skip nested functions |
| 334 | + pass |
| 335 | + |
| 336 | + |
317 | 337 | class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): |
318 | 338 | """Mypy type checker. |
319 | 339 |
|
@@ -1620,20 +1640,6 @@ def is_unannotated_any(t: Type) -> bool: |
1620 | 1640 | # Collect return types from return statements |
1621 | 1641 | # Use the master type map (first in stack) where final types are stored |
1622 | 1642 | # At this point in type checking, return statement types should be in the master map |
1623 | | - # Inline get_return_types to avoid circular import with mypy.suggestions |
1624 | | - class ReturnTypeFinder(TraverserVisitor): |
1625 | | - def __init__(self, typemap: dict[Expression, Type]) -> None: |
1626 | | - self.typemap = typemap |
1627 | | - self.return_types: list[Type] = [] |
1628 | | - |
1629 | | - def visit_return_stmt(self, o: ReturnStmt) -> None: |
1630 | | - if o.expr is not None and o.expr in self.typemap: |
1631 | | - self.return_types.append(self.typemap[o.expr]) |
1632 | | - |
1633 | | - def visit_func_def(self, o: FuncDef) -> None: |
1634 | | - # Skip nested functions |
1635 | | - pass |
1636 | | - |
1637 | 1643 | finder = ReturnTypeFinder(self._type_maps[0]) |
1638 | 1644 | if item.body: |
1639 | 1645 | item.body.accept(finder) |
|
0 commit comments