|
8 | 8 |
|
9 | 9 | import ast |
10 | 10 | import tokenize |
| 11 | +from collections.abc import Iterator |
11 | 12 | from dataclasses import dataclass, field |
12 | 13 | from typing import TYPE_CHECKING, Literal, NamedTuple |
13 | 14 |
|
@@ -86,6 +87,8 @@ class _ModuleWalkState: |
86 | 87 | name_nodes: list[ast.Name] = field(default_factory=list) |
87 | 88 | attr_nodes: list[ast.Attribute] = field(default_factory=list) |
88 | 89 | exported_names: set[str] = field(default_factory=set) |
| 90 | + lazy_export_bindings: dict[str, set[str]] = field(default_factory=dict) |
| 91 | + has_module_getattr: bool = False |
89 | 92 | protocol_symbol_aliases: set[str] = field(default_factory=lambda: {"Protocol"}) |
90 | 93 | protocol_module_aliases: set[str] = field( |
91 | 94 | default_factory=lambda: set(_PROTOCOL_MODULE_NAMES) |
@@ -185,6 +188,21 @@ def _string_literals_from_export_value(value: ast.AST) -> tuple[str, ...]: |
185 | 188 | return () |
186 | 189 |
|
187 | 190 |
|
| 191 | +def _string_mapping_from_literal_dict(value: ast.AST) -> dict[str, str]: |
| 192 | + if not isinstance(value, ast.Dict): |
| 193 | + return {} |
| 194 | + mapping: dict[str, str] = {} |
| 195 | + for key, val in zip(value.keys, value.values, strict=True): |
| 196 | + if ( |
| 197 | + isinstance(key, ast.Constant) |
| 198 | + and isinstance(key.value, str) |
| 199 | + and isinstance(val, ast.Constant) |
| 200 | + and isinstance(val.value, str) |
| 201 | + ): |
| 202 | + mapping[key.value] = val.value |
| 203 | + return mapping |
| 204 | + |
| 205 | + |
188 | 206 | def _collect_all_export_node(node: ast.AST, state: _ModuleWalkState) -> None: |
189 | 207 | match node: |
190 | 208 | case ast.Assign(targets=targets, value=value): |
@@ -216,11 +234,128 @@ def _collect_all_export_node(node: ast.AST, state: _ModuleWalkState) -> None: |
216 | 234 | pass |
217 | 235 |
|
218 | 236 |
|
| 237 | +def _collect_lazy_export_node(node: ast.AST, state: _ModuleWalkState) -> None: |
| 238 | + match node: |
| 239 | + case ast.Assign(targets=targets, value=value): |
| 240 | + names = {target.id for target in targets if isinstance(target, ast.Name)} |
| 241 | + case ast.AnnAssign(target=ast.Name(id=name), value=value): |
| 242 | + names = {name} |
| 243 | + case ( |
| 244 | + ast.FunctionDef(name="__getattr__") |
| 245 | + | ast.AsyncFunctionDef(name="__getattr__") |
| 246 | + ): |
| 247 | + state.has_module_getattr = True |
| 248 | + return |
| 249 | + case _: |
| 250 | + return |
| 251 | + if "_EXPORTS" not in names or value is None: |
| 252 | + return |
| 253 | + for exported_name, module_path in _string_mapping_from_literal_dict(value).items(): |
| 254 | + state.lazy_export_bindings.setdefault(exported_name, set()).add(module_path) |
| 255 | + |
| 256 | + |
219 | 257 | def _collect_module_all_exports(tree: ast.AST, state: _ModuleWalkState) -> None: |
220 | 258 | if not isinstance(tree, ast.Module): |
221 | 259 | return |
222 | 260 | for statement in tree.body: |
223 | 261 | _collect_all_export_node(statement, state) |
| 262 | + _collect_lazy_export_node(statement, state) |
| 263 | + |
| 264 | + |
| 265 | +def _literal_getattr_name(value: ast.AST | None) -> str | None: |
| 266 | + if not isinstance(value, ast.Call): |
| 267 | + return None |
| 268 | + if not isinstance(value.func, ast.Name) or value.func.id != "getattr": |
| 269 | + return None |
| 270 | + if len(value.args) < 2: |
| 271 | + return None |
| 272 | + attr_arg = value.args[1] |
| 273 | + if not isinstance(attr_arg, ast.Constant) or not isinstance(attr_arg.value, str): |
| 274 | + return None |
| 275 | + if attr_arg.value.isidentifier(): |
| 276 | + return attr_arg.value |
| 277 | + return None |
| 278 | + |
| 279 | + |
| 280 | +def _iter_runtime_callable_scopes( |
| 281 | + tree: ast.AST, |
| 282 | +) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]: |
| 283 | + if not isinstance(tree, ast.Module): |
| 284 | + return |
| 285 | + stack = list(reversed(tree.body)) |
| 286 | + while stack: |
| 287 | + node = stack.pop() |
| 288 | + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef): |
| 289 | + yield node |
| 290 | + continue |
| 291 | + if isinstance(node, ast.ClassDef): |
| 292 | + stack.extend(reversed(node.body)) |
| 293 | + |
| 294 | + |
| 295 | +def _iter_scope_body_nodes(body: list[ast.stmt]) -> Iterator[ast.AST]: |
| 296 | + stack: list[ast.AST] = list(reversed(body)) |
| 297 | + while stack: |
| 298 | + node = stack.pop() |
| 299 | + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): |
| 300 | + continue |
| 301 | + yield node |
| 302 | + stack.extend(reversed(list(ast.iter_child_nodes(node)))) |
| 303 | + |
| 304 | + |
| 305 | +def _dynamic_getattr_names_from_scope( |
| 306 | + node: ast.FunctionDef | ast.AsyncFunctionDef, |
| 307 | +) -> set[str]: |
| 308 | + getattr_bindings: dict[str, str] = {} |
| 309 | + callable_guards: set[str] = set() |
| 310 | + called_locals: set[str] = set() |
| 311 | + for scope_node in _iter_scope_body_nodes(node.body): |
| 312 | + match scope_node: |
| 313 | + case ast.Assign(targets=targets, value=value): |
| 314 | + attr_name = _literal_getattr_name(value) |
| 315 | + if attr_name is not None: |
| 316 | + for target in targets: |
| 317 | + if isinstance(target, ast.Name): |
| 318 | + getattr_bindings[target.id] = attr_name |
| 319 | + case ast.AnnAssign(target=ast.Name(id=name), value=value): |
| 320 | + attr_name = _literal_getattr_name(value) |
| 321 | + if attr_name is not None: |
| 322 | + getattr_bindings[name] = attr_name |
| 323 | + case ast.Call( |
| 324 | + func=ast.Name(id="callable"), |
| 325 | + args=[ast.Name(id=name), *_], |
| 326 | + ): |
| 327 | + callable_guards.add(name) |
| 328 | + case ast.Call(func=ast.Name(id=name)): |
| 329 | + called_locals.add(name) |
| 330 | + case _: |
| 331 | + pass |
| 332 | + return { |
| 333 | + attr_name |
| 334 | + for local_name, attr_name in getattr_bindings.items() |
| 335 | + if local_name in callable_guards and local_name in called_locals |
| 336 | + } |
| 337 | + |
| 338 | + |
| 339 | +def _collect_dynamic_getattr_names(tree: ast.AST) -> set[str]: |
| 340 | + names: set[str] = set() |
| 341 | + for scope in _iter_runtime_callable_scopes(tree): |
| 342 | + names.update(_dynamic_getattr_names_from_scope(scope)) |
| 343 | + return names |
| 344 | + |
| 345 | + |
| 346 | +def _local_export_qualname( |
| 347 | + *, |
| 348 | + module_name: str, |
| 349 | + exported_name: str, |
| 350 | + functions_by_name: dict[str, str], |
| 351 | + classes_by_name: dict[str, str], |
| 352 | +) -> str | None: |
| 353 | + local_qualname = functions_by_name.get(exported_name) |
| 354 | + if local_qualname is None: |
| 355 | + local_qualname = classes_by_name.get(exported_name) |
| 356 | + if local_qualname is None: |
| 357 | + return None |
| 358 | + return f"{module_name}:{local_qualname}" |
224 | 359 |
|
225 | 360 |
|
226 | 361 | def _collect_import_from_node( |
@@ -472,13 +607,19 @@ def _resolve_referenced_qualnames( |
472 | 607 | resolved.add(local_method_qualname) |
473 | 608 |
|
474 | 609 | for exported_name in state.exported_names: |
475 | | - local_qualname = top_level_function_by_name.get(exported_name) |
476 | | - if local_qualname is not None: |
477 | | - resolved.add(f"{module_name}:{local_qualname}") |
| 610 | + local_export_qualname = _local_export_qualname( |
| 611 | + module_name=module_name, |
| 612 | + exported_name=exported_name, |
| 613 | + functions_by_name=top_level_function_by_name, |
| 614 | + classes_by_name=top_level_class_by_name, |
| 615 | + ) |
| 616 | + if local_export_qualname is not None: |
| 617 | + resolved.add(local_export_qualname) |
478 | 618 | continue |
479 | | - class_qualname = top_level_class_by_name.get(exported_name) |
480 | | - if class_qualname is not None: |
481 | | - resolved.add(f"{module_name}:{class_qualname}") |
| 619 | + resolved.update(state.imported_symbol_bindings.get(exported_name, ())) |
| 620 | + if state.has_module_getattr: |
| 621 | + for module_path in state.lazy_export_bindings.get(exported_name, ()): |
| 622 | + resolved.add(f"{module_path}:{exported_name}") |
482 | 623 |
|
483 | 624 | return frozenset(resolved) |
484 | 625 |
|
@@ -524,6 +665,8 @@ def _collect_module_walk_data( |
524 | 665 | ) |
525 | 666 | elif collect_referenced_names: |
526 | 667 | _collect_load_reference_node(node=node, state=state) |
| 668 | + if collect_referenced_names: |
| 669 | + state.referenced_names.update(_collect_dynamic_getattr_names(tree)) |
527 | 670 |
|
528 | 671 | deps_sorted = tuple( |
529 | 672 | sorted( |
|
0 commit comments