Skip to content

Commit 046e66e

Browse files
authored
Exclude old-style imports when they will be replaced with subpackage imports (#6294)
* Exclude old-style imports when they will be replaced with subpackage imports Escape hatch for packages outside of the reflex repo to generate __init__.py stubs without assuming "components" is "reflex_components_core" * greptile feedback * better visit_Import handling to avoid duplicate imports and missing default imports * add is_reflex_package escape hatch to second condition in _rewrite_component_import
1 parent 7140079 commit 046e66e

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

packages/reflex-base/src/reflex_base/utils/pyi_generator.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ def _safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]) -> bool:
121121
"reflex_base.style": ["Style"],
122122
"reflex_base.vars.base": ["Var"],
123123
}
124+
# These pre-0.9 imports might be present in the file and should be removed since the pyi generator will handle them separately.
125+
EXCLUDED_IMPORTS = {
126+
"reflex.components.core.breakpoints": ["Breakpoints"],
127+
"reflex.event": [
128+
"EventChain",
129+
"EventHandler",
130+
"EventSpec",
131+
"EventType",
132+
"KeyInputInfo",
133+
"PointerEventInfo",
134+
],
135+
"reflex.style": ["Style"],
136+
"reflex.vars.base": ["Var"],
137+
}
124138

125139

126140
def _walk_files(path: str | Path):
@@ -1074,10 +1088,19 @@ def visit_Import(
10741088
Returns:
10751089
The modified import node(s).
10761090
"""
1091+
# Drop any imports in the EXCLUDED_IMPORTS mapping since they are supplied by DEFAULT_IMPORTS.
1092+
if isinstance(node, ast.ImportFrom) and node.module in EXCLUDED_IMPORTS:
1093+
node.names = [
1094+
alias
1095+
for alias in node.names
1096+
if alias.name not in EXCLUDED_IMPORTS[node.module]
1097+
]
10771098
if not self.inserted_imports:
10781099
self.inserted_imports = True
10791100
default_imports = _generate_imports(self.typing_imports)
1080-
return [*default_imports, node]
1101+
return [*default_imports, *([node] if node.names else ())]
1102+
if not node.names:
1103+
return []
10811104
return node
10821105

10831106
def visit_ImportFrom(
@@ -1330,21 +1353,22 @@ def _write_pyi_file(module_path: Path, source: str) -> str:
13301353
}
13311354

13321355

1333-
def _rewrite_component_import(module: str) -> str:
1356+
def _rewrite_component_import(module: str, is_reflex_package: bool) -> str:
13341357
"""Rewrite a lazy-loader module path to the correct absolute package import.
13351358
13361359
Args:
13371360
module: The module path from ``_SUBMOD_ATTRS`` (e.g. ``"components.radix.themes.base"``).
1361+
is_reflex_package: Whether the module is part of the Reflex package.
13381362
13391363
Returns:
13401364
An absolute import path (``"reflex_components_radix.themes.base"``) for moved
13411365
components, or a relative path (``".components.component"``) for everything else.
13421366
"""
1343-
if module == "components":
1367+
if is_reflex_package and module == "components":
13441368
# "components": ["el", "radix", ...] — these are re-exported submodules.
13451369
# Can't map to a single package, but the pyi generator handles each attr individually.
13461370
return "reflex_components_core"
1347-
if module.startswith("components."):
1371+
if is_reflex_package and module.startswith("components."):
13481372
rest = module[len("components.") :]
13491373
# Try progressively deeper matches (e.g. "datadisplay.code" before "datadisplay").
13501374
parts = rest.split(".")
@@ -1357,7 +1381,7 @@ def _rewrite_component_import(module: str) -> str:
13571381
return f".{module}"
13581382

13591383

1360-
def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
1384+
def _get_init_lazy_imports(mod: ModuleType, new_tree: ast.AST):
13611385
# retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present.
13621386
sub_mods: set[str] | None = getattr(mod, "_SUBMODULES", None)
13631387
sub_mod_attrs: dict[str, list[str | tuple[str, str]]] | None = getattr(
@@ -1375,6 +1399,8 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
13751399
sub_mods_imports = [f"from . import {mod}" for mod in sorted(sub_mods)]
13761400
sub_mods_imports.append("")
13771401

1402+
is_reflex_package = bool(mod.__name__.partition(".")[0] == "reflex")
1403+
13781404
if sub_mod_attrs:
13791405
flattened_sub_mod_attrs = {
13801406
imported: module
@@ -1385,15 +1411,16 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
13851411
for imported, module in flattened_sub_mod_attrs.items():
13861412
# For "components": ["el", "radix", ...], resolve each attr to its package.
13871413
if (
1388-
module == "components"
1414+
is_reflex_package
1415+
and module == "components"
13891416
and isinstance(imported, str)
13901417
and imported in _COMPONENT_SUBPACKAGE_TARGETS
13911418
):
13921419
target = _COMPONENT_SUBPACKAGE_TARGETS[imported]
13931420
sub_mod_attrs_imports.append(f"import {target} as {imported}")
13941421
continue
13951422

1396-
rewritten = _rewrite_component_import(module)
1423+
rewritten = _rewrite_component_import(module, is_reflex_package)
13971424
if isinstance(imported, tuple):
13981425
suffix = (
13991426
(imported[0] + " as " + imported[1])

0 commit comments

Comments
 (0)