Skip to content

Commit 5e89487

Browse files
committed
Handle nested definitions
1 parent 94fa981 commit 5e89487

4 files changed

Lines changed: 148 additions & 14 deletions

File tree

mypy/build.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@
103103
ImportBase,
104104
ImportFrom,
105105
MypyFile,
106-
SymbolNode,
107106
SymbolTable,
108107
)
109108
from mypy.partially_defined import PossiblyUndefinedVariableVisitor
110109
from mypy.semanal import SemanticAnalyzer
111110
from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis
111+
from mypy.traverser import find_definitions
112112
from mypy.util import (
113113
DecodeError,
114114
decode_python_encoding,
@@ -1085,22 +1085,20 @@ def resolve_location(self, graph: dict[str, State], fullname: str) -> Context |
10851085
source = decode_python_encoding(state.manager.fscache.read(path))
10861086
tree = parse(source, state.path, state.id, state.manager.errors, state.options)
10871087
self.extra_trees[state.id] = tree
1088-
defs = tree.defs
1088+
statements = tree.defs
10891089
while prefix:
10901090
part = prefix.pop(0)
1091-
for defn in defs:
1092-
if not isinstance(defn, ClassDef):
1091+
for statement in statements:
1092+
defs = find_definitions(statement, part)
1093+
if not defs or not isinstance((defn := defs[0]), ClassDef):
10931094
continue
1094-
if defn.name == part:
1095-
defs = defn.defs.body
1096-
break
1095+
statements = defn.defs.body
1096+
break
10971097
else:
10981098
return None
1099-
for defn in defs:
1100-
# TODO: support more kinds of locations (like assignment statements).
1101-
# the latter will be helpful for type old-style aliases.
1102-
if isinstance(defn, SymbolNode) and defn.name == name:
1103-
return defn
1099+
for statement in statements:
1100+
if defs := find_definitions(statement, name):
1101+
return defs[0]
11041102
return None
11051103

11061104
def verbosity(self) -> int:

mypy/plugins/functools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
307307

308308
# Technically, we should set definition to None here, since it will not be recovered
309309
# on warm cache runs in fixup.py. This however may hide some helpful info in error
310-
# messages, so we are keeping it for now.
310+
# messages, so we are keeping it for now. See also issue #20640.
311311
partially_applied = fn_type.copy_modified(
312312
arg_types=partial_types,
313313
arg_kinds=partial_kinds,

mypy/traverser.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
SetExpr,
6767
SliceExpr,
6868
StarExpr,
69+
Statement,
6970
StrExpr,
7071
SuperExpr,
7172
TempNode,
@@ -96,7 +97,7 @@
9697
StarredPattern,
9798
ValuePattern,
9899
)
99-
from mypy.visitor import NodeVisitor
100+
from mypy.visitor import NodeVisitor, StatementVisitor
100101

101102

102103
@trait
@@ -108,6 +109,10 @@ class TraverserVisitor(NodeVisitor[None]):
108109
should override visit methods to perform actions during
109110
traversal. Calling the superclass method allows reusing the
110111
traversal implementation.
112+
113+
TODO: split this into more limited visitor (e.g. statements-only etc).
114+
This will improve performance since in many cases we don't need to recurse
115+
all the way down in various visitors that subclass this.
111116
"""
112117

113118
def __init__(self) -> None:
@@ -1084,3 +1089,115 @@ def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
10841089

10851090
def visit_yield_from_expr(self, expr: YieldFromExpr) -> None:
10861091
self.yield_from_expressions.append((expr, self.in_assignment))
1092+
1093+
1094+
def find_definitions(o: Statement, name: str) -> list[Statement]:
1095+
visitor = DefinitionSeeker(name)
1096+
o.accept(visitor)
1097+
return visitor.found
1098+
1099+
1100+
class DefinitionSeeker(StatementVisitor[None]):
1101+
def __init__(self, name: str) -> None:
1102+
self.name = name
1103+
self.found: list[Statement] = []
1104+
1105+
def visit_assignment_stmt(self, o: AssignmentStmt, /) -> None:
1106+
# TODO: support more kinds of locations (like assignment statements).
1107+
# the latter will be helpful for type old-style aliases.
1108+
pass
1109+
1110+
def visit_for_stmt(self, o: ForStmt, /) -> None:
1111+
o.body.accept(self)
1112+
if o.else_body:
1113+
o.else_body.accept(self)
1114+
1115+
def visit_with_stmt(self, o: WithStmt, /) -> None:
1116+
o.body.accept(self)
1117+
1118+
def visit_del_stmt(self, o: DelStmt, /) -> None:
1119+
pass
1120+
1121+
def visit_func_def(self, o: FuncDef, /) -> None:
1122+
if o.name == self.name:
1123+
self.found.append(o)
1124+
1125+
def visit_overloaded_func_def(self, o: OverloadedFuncDef, /) -> None:
1126+
if o.name == self.name:
1127+
self.found.append(o)
1128+
1129+
def visit_class_def(self, o: ClassDef, /) -> None:
1130+
if o.name == self.name:
1131+
self.found.append(o)
1132+
1133+
def visit_global_decl(self, o: GlobalDecl, /) -> None:
1134+
pass
1135+
1136+
def visit_nonlocal_decl(self, o: NonlocalDecl, /) -> None:
1137+
pass
1138+
1139+
def visit_decorator(self, o: Decorator, /) -> None:
1140+
if o.name == self.name:
1141+
self.found.append(o)
1142+
1143+
def visit_import(self, o: Import, /) -> None:
1144+
pass
1145+
1146+
def visit_import_from(self, o: ImportFrom, /) -> None:
1147+
pass
1148+
1149+
def visit_import_all(self, o: ImportAll, /) -> None:
1150+
pass
1151+
1152+
def visit_block(self, o: Block, /) -> None:
1153+
for s in o.body:
1154+
s.accept(self)
1155+
1156+
def visit_expression_stmt(self, o: ExpressionStmt, /) -> None:
1157+
pass
1158+
1159+
def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt, /) -> None:
1160+
pass
1161+
1162+
def visit_while_stmt(self, o: WhileStmt, /) -> None:
1163+
o.body.accept(self)
1164+
if o.else_body:
1165+
o.else_body.accept(self)
1166+
1167+
def visit_return_stmt(self, o: ReturnStmt, /) -> None:
1168+
pass
1169+
1170+
def visit_assert_stmt(self, o: AssertStmt, /) -> None:
1171+
pass
1172+
1173+
def visit_if_stmt(self, o: IfStmt, /) -> None:
1174+
for b in o.body:
1175+
b.accept(self)
1176+
if o.else_body:
1177+
o.else_body.accept(self)
1178+
1179+
def visit_break_stmt(self, o: BreakStmt, /) -> None:
1180+
pass
1181+
1182+
def visit_continue_stmt(self, o: ContinueStmt, /) -> None:
1183+
pass
1184+
1185+
def visit_pass_stmt(self, o: PassStmt, /) -> None:
1186+
pass
1187+
1188+
def visit_raise_stmt(self, o: RaiseStmt, /) -> None:
1189+
pass
1190+
1191+
def visit_try_stmt(self, o: TryStmt, /) -> None:
1192+
o.body.accept(self)
1193+
if o.else_body is not None:
1194+
o.else_body.accept(self)
1195+
if o.finally_body is not None:
1196+
o.finally_body.accept(self)
1197+
1198+
def visit_match_stmt(self, o: MatchStmt, /) -> None:
1199+
for b in o.bodies:
1200+
b.accept(self)
1201+
1202+
def visit_type_alias_stmt(self, o: TypeAliasStmt, /) -> None:
1203+
pass

test-data/unit/check-incremental.test

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7730,3 +7730,22 @@ tmp/b.py:2: note: "lol" defined here
77307730
[out4]
77317731
tmp/a.py:2: error: Unexpected keyword argument "uhhhh" for "lol"
77327732
tmp/b.py:2: note: "lol" defined here
7733+
7734+
[case testCachedUnexpectedKeywordArgumentNested]
7735+
import a
7736+
[file a.py]
7737+
import b
7738+
b.lol(uhhhh=12) # tweak
7739+
[file a.py.2]
7740+
import b
7741+
b.lol(uhhhh=12)
7742+
[file b.py]
7743+
while True:
7744+
if True:
7745+
def lol() -> None: pass
7746+
[out]
7747+
tmp/a.py:2: error: Unexpected keyword argument "uhhhh" for "lol"
7748+
tmp/b.py:3: note: "lol" defined here
7749+
[out2]
7750+
tmp/a.py:2: error: Unexpected keyword argument "uhhhh" for "lol"
7751+
tmp/b.py:3: note: "lol" defined here

0 commit comments

Comments
 (0)