diff --git a/packages/reflex-base/src/reflex_base/utils/pyi_generator.py b/packages/reflex-base/src/reflex_base/utils/pyi_generator.py index ca5409aa49a..3479f10f321 100644 --- a/packages/reflex-base/src/reflex_base/utils/pyi_generator.py +++ b/packages/reflex-base/src/reflex_base/utils/pyi_generator.py @@ -121,6 +121,20 @@ def _safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]) -> bool: "reflex_base.style": ["Style"], "reflex_base.vars.base": ["Var"], } +# These pre-0.9 imports might be present in the file and should be removed since the pyi generator will handle them separately. +EXCLUDED_IMPORTS = { + "reflex.components.core.breakpoints": ["Breakpoints"], + "reflex.event": [ + "EventChain", + "EventHandler", + "EventSpec", + "EventType", + "KeyInputInfo", + "PointerEventInfo", + ], + "reflex.style": ["Style"], + "reflex.vars.base": ["Var"], +} def _walk_files(path: str | Path): @@ -1074,10 +1088,19 @@ def visit_Import( Returns: The modified import node(s). """ + # Drop any imports in the EXCLUDED_IMPORTS mapping since they are supplied by DEFAULT_IMPORTS. + if isinstance(node, ast.ImportFrom) and node.module in EXCLUDED_IMPORTS: + node.names = [ + alias + for alias in node.names + if alias.name not in EXCLUDED_IMPORTS[node.module] + ] if not self.inserted_imports: self.inserted_imports = True default_imports = _generate_imports(self.typing_imports) - return [*default_imports, node] + return [*default_imports, *([node] if node.names else ())] + if not node.names: + return [] return node def visit_ImportFrom( @@ -1330,21 +1353,22 @@ def _write_pyi_file(module_path: Path, source: str) -> str: } -def _rewrite_component_import(module: str) -> str: +def _rewrite_component_import(module: str, is_reflex_package: bool) -> str: """Rewrite a lazy-loader module path to the correct absolute package import. Args: module: The module path from ``_SUBMOD_ATTRS`` (e.g. ``"components.radix.themes.base"``). + is_reflex_package: Whether the module is part of the Reflex package. Returns: An absolute import path (``"reflex_components_radix.themes.base"``) for moved components, or a relative path (``".components.component"``) for everything else. """ - if module == "components": + if is_reflex_package and module == "components": # "components": ["el", "radix", ...] — these are re-exported submodules. # Can't map to a single package, but the pyi generator handles each attr individually. return "reflex_components_core" - if module.startswith("components."): + if is_reflex_package and module.startswith("components."): rest = module[len("components.") :] # Try progressively deeper matches (e.g. "datadisplay.code" before "datadisplay"). parts = rest.split(".") @@ -1357,7 +1381,7 @@ def _rewrite_component_import(module: str) -> str: return f".{module}" -def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST): +def _get_init_lazy_imports(mod: ModuleType, new_tree: ast.AST): # retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present. sub_mods: set[str] | None = getattr(mod, "_SUBMODULES", None) sub_mod_attrs: dict[str, list[str | tuple[str, str]]] | None = getattr( @@ -1375,6 +1399,8 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST): sub_mods_imports = [f"from . import {mod}" for mod in sorted(sub_mods)] sub_mods_imports.append("") + is_reflex_package = bool(mod.__name__.partition(".")[0] == "reflex") + if sub_mod_attrs: flattened_sub_mod_attrs = { imported: module @@ -1385,7 +1411,8 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST): for imported, module in flattened_sub_mod_attrs.items(): # For "components": ["el", "radix", ...], resolve each attr to its package. if ( - module == "components" + is_reflex_package + and module == "components" and isinstance(imported, str) and imported in _COMPONENT_SUBPACKAGE_TARGETS ): @@ -1393,7 +1420,7 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST): sub_mod_attrs_imports.append(f"import {target} as {imported}") continue - rewritten = _rewrite_component_import(module) + rewritten = _rewrite_component_import(module, is_reflex_package) if isinstance(imported, tuple): suffix = ( (imported[0] + " as " + imported[1])