Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions packages/reflex-base/src/reflex_base/utils/pyi_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def _safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]) -> bool:
"reflex_base.style": ["Style"],
"reflex_base.vars.base": ["Var"],
}
# These pre-0.9 imports might be present in the file and should be removed since the pyi generator will handle them separately.
EXCLUDED_IMPORTS = {
"reflex.components.core.breakpoints": ["Breakpoints"],
"reflex.event": [
"EventChain",
"EventHandler",
"EventSpec",
"EventType",
"KeyInputInfo",
"PointerEventInfo",
],
"reflex.style": ["Style"],
"reflex.vars.base": ["Var"],
}


def _walk_files(path: str | Path):
Expand Down Expand Up @@ -1074,10 +1088,19 @@ def visit_Import(
Returns:
The modified import node(s).
"""
# Drop any imports in the EXCLUDED_IMPORTS mapping since they are supplied by DEFAULT_IMPORTS.
if isinstance(node, ast.ImportFrom) and node.module in EXCLUDED_IMPORTS:
node.names = [
alias
for alias in node.names
if alias.name not in EXCLUDED_IMPORTS[node.module]
]
if not self.inserted_imports:
self.inserted_imports = True
default_imports = _generate_imports(self.typing_imports)
return [*default_imports, node]
return [*default_imports, *([node] if node.names else ())]
if not node.names:
return []
return node

def visit_ImportFrom(
Expand Down Expand Up @@ -1330,21 +1353,22 @@ def _write_pyi_file(module_path: Path, source: str) -> str:
}


def _rewrite_component_import(module: str) -> str:
def _rewrite_component_import(module: str, is_reflex_package: bool) -> str:
"""Rewrite a lazy-loader module path to the correct absolute package import.

Args:
module: The module path from ``_SUBMOD_ATTRS`` (e.g. ``"components.radix.themes.base"``).
is_reflex_package: Whether the module is part of the Reflex package.

Returns:
An absolute import path (``"reflex_components_radix.themes.base"``) for moved
components, or a relative path (``".components.component"``) for everything else.
"""
if module == "components":
if is_reflex_package and module == "components":
# "components": ["el", "radix", ...] — these are re-exported submodules.
# Can't map to a single package, but the pyi generator handles each attr individually.
return "reflex_components_core"
if module.startswith("components."):
if is_reflex_package and module.startswith("components."):
rest = module[len("components.") :]
# Try progressively deeper matches (e.g. "datadisplay.code" before "datadisplay").
parts = rest.split(".")
Expand All @@ -1357,7 +1381,7 @@ def _rewrite_component_import(module: str) -> str:
return f".{module}"
Comment on lines 1367 to 1381
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 components.* subpackage lookup not guarded by is_reflex_package

The PR correctly guards the module == "components" branch with is_reflex_package, but the immediately-following if module.startswith("components."): branch on line 1371 is left unguarded. If a non-reflex package's _SUBMOD_ATTRS happens to contain a key like "components.recharts" (coincidentally matching an entry in _COMPONENT_SUBPACKAGE_TARGETS), it would be silently rewritten to "reflex_components_recharts", producing a broken import for that package.

For consistency, the same guard should wrap the startswith branch:

Suggested change
if is_reflex_package and module == "components":
# "components": ["el", "radix", ...] — these are re-exported submodules.
# Can't map to a single package, but the pyi generator handles each attr individually.
return "reflex_components_core"
if is_reflex_package and module.startswith("components."):
rest = module[len("components."):]
# Try progressively deeper matches (e.g. "datadisplay.code" before "datadisplay").
parts = rest.split(".")
for depth in range(min(len(parts), 2), 0, -1):
key = ".".join(parts[:depth])
target = _COMPONENT_SUBPACKAGE_TARGETS.get(key)
if target is not None:
remainder = ".".join(parts[depth:])
return f"{target}.{remainder}" if remainder else target
return f".{module}"



def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
def _get_init_lazy_imports(mod: ModuleType, new_tree: ast.AST):
# retrieve the _SUBMODULES and _SUBMOD_ATTRS from an init file if present.
sub_mods: set[str] | None = getattr(mod, "_SUBMODULES", None)
sub_mod_attrs: dict[str, list[str | tuple[str, str]]] | None = getattr(
Expand All @@ -1375,6 +1399,8 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
sub_mods_imports = [f"from . import {mod}" for mod in sorted(sub_mods)]
sub_mods_imports.append("")

is_reflex_package = bool(mod.__name__.partition(".")[0] == "reflex")

if sub_mod_attrs:
flattened_sub_mod_attrs = {
imported: module
Expand All @@ -1385,15 +1411,16 @@ def _get_init_lazy_imports(mod: tuple | ModuleType, new_tree: ast.AST):
for imported, module in flattened_sub_mod_attrs.items():
# For "components": ["el", "radix", ...], resolve each attr to its package.
if (
module == "components"
is_reflex_package
and module == "components"
and isinstance(imported, str)
and imported in _COMPONENT_SUBPACKAGE_TARGETS
):
target = _COMPONENT_SUBPACKAGE_TARGETS[imported]
sub_mod_attrs_imports.append(f"import {target} as {imported}")
continue

rewritten = _rewrite_component_import(module)
rewritten = _rewrite_component_import(module, is_reflex_package)
if isinstance(imported, tuple):
suffix = (
(imported[0] + " as " + imported[1])
Expand Down
Loading