Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from reflex.components.base.fragment import Fragment
from reflex.components.base.strict_mode import StrictMode
from reflex.components.component import (
CUSTOM_COMPONENTS,
Component,
ComponentStyle,
evaluate_style_namespaces,
Expand Down Expand Up @@ -1222,9 +1223,8 @@ def get_compilation_time() -> str:

progress.advance(task)

# Track imports and custom components found.
# Track imports found.
all_imports = {}
custom_components = set()

# This has to happen before compiling stateful components as that
# prevents recursive functions from reaching all components.
Expand All @@ -1235,9 +1235,6 @@ def get_compilation_time() -> str:
# Add the app wrappers from this component.
app_wrappers.update(component._get_all_app_wrap_components())

# Add the custom components from the page to the set.
custom_components |= component._get_all_custom_components()

if (toaster := self.toaster) is not None:
from reflex.components.component import memo

Expand All @@ -1255,9 +1252,6 @@ def memoized_toast_provider():
if component is not None:
app_wrappers[key] = component

for component in app_wrappers.values():
custom_components |= component._get_all_custom_components()

if self.error_boundary:
from reflex.compiler.compiler import into_component

Expand Down Expand Up @@ -1382,7 +1376,7 @@ def _submit_work(fn: Callable[..., tuple[str, str]], *args, **kwargs):
custom_components_output,
custom_components_result,
custom_components_imports,
) = compiler.compile_components(custom_components)
) = compiler.compile_components(set(CUSTOM_COMPONENTS.values()))
compile_results.append((custom_components_output, custom_components_result))
all_imports.update(custom_components_imports)

Expand Down
5 changes: 1 addition & 4 deletions reflex/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _normalize_library_name(lib: str) -> str:
"""
if lib == "react":
return "React"
return lib.replace("@", "").replace("/", "_").replace("-", "_")
return lib.replace("$/", "").replace("@", "").replace("/", "_").replace("-", "_")


def _compile_app(app_root: Component) -> str:
Expand All @@ -72,9 +72,6 @@ def _compile_app(app_root: Component) -> str:

window_libraries = [
(_normalize_library_name(name), name) for name in bundled_libraries
] + [
("utils_context", f"$/{constants.Dirs.UTILS}/context"),
("utils_state", f"$/{constants.Dirs.UTILS}/state"),
]

return templates.APP_ROOT.render(
Expand Down
96 changes: 39 additions & 57 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,32 +1647,6 @@ def _get_all_refs(self) -> set[str]:

return refs

def _get_all_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]:
"""Get all the custom components used by the component.

Args:
seen: The tags of the components that have already been seen.

Returns:
The set of custom components.
"""
custom_components = set()

# Store the seen components in a set to avoid infinite recursion.
if seen is None:
seen = set()
for child in self.children:
# Skip BaseComponent and StatefulComponent children.
if not isinstance(child, Component):
continue
custom_components |= child._get_all_custom_components(seen=seen)
for component in self._get_components_in_props():
if isinstance(component, Component) and component.tag is not None:
custom_components |= component._get_all_custom_components(seen=seen)
return custom_components

@property
def import_var(self):
"""The tag to import.
Expand Down Expand Up @@ -1857,37 +1831,6 @@ def get_props(cls) -> set[str]:
"""
return set()

def _get_all_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]:
"""Get all the custom components used by the component.

Args:
seen: The tags of the components that have already been seen.

Raises:
ValueError: If the tag is not set.

Returns:
The set of custom components.
"""
if self.tag is None:
raise ValueError("The tag must be set.")

# Store the seen components in a set to avoid infinite recursion.
if seen is None:
seen = set()
custom_components = {self} | super()._get_all_custom_components(seen=seen)

# Avoid adding the same component twice.
if self.tag not in seen:
seen.add(self.tag)
custom_components |= self.get_component(self)._get_all_custom_components(
seen=seen
)

return custom_components

@staticmethod
def _get_event_spec_from_args_spec(name: str, event: EventChain) -> Callable:
"""Get the event spec from the args spec.
Expand Down Expand Up @@ -1951,6 +1894,42 @@ def get_component(self) -> Component:
return self.component_fn(*self.get_prop_vars())


CUSTOM_COMPONENTS: dict[str, CustomComponent] = {}


def _register_custom_component(
component_fn: Callable[..., Component],
):
"""Register a custom component to be compiled.

Args:
component_fn: The function that creates the component.

Raises:
TypeError: If the tag name cannot be determined.
"""
dummy_props = {
prop: (
Var(
"",
_var_type=annotation,
)
if not types.safe_issubclass(annotation, EventHandler)
else EventSpec(handler=EventHandler(fn=lambda: []))
)
for prop, annotation in typing.get_type_hints(component_fn).items()
if prop != "return"
}
dummy_component = CustomComponent._create(
children=[],
component_fn=component_fn,
**dummy_props,
)
if dummy_component.tag is None:
raise TypeError(f"Could not determine the tag name for {component_fn!r}")
CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component


def custom_component(
component_fn: Callable[..., Component],
) -> Callable[..., CustomComponent]:
Expand All @@ -1971,6 +1950,9 @@ def wrapper(*children, **props) -> CustomComponent:
children=list(children), component_fn=component_fn, **props
)

# Register this component so it can be compiled.
_register_custom_component(component_fn)

return wrapper


Expand Down
10 changes: 9 additions & 1 deletion reflex/components/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ def get_cdn_url(lib: str) -> str:
return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm"


bundled_libraries = {"react", "@radix-ui/themes", "@emotion/react", "next/link"}
bundled_libraries = {
"react",
"@radix-ui/themes",
"@emotion/react",
"next/link",
f"$/{constants.Dirs.UTILS}/context",
f"$/{constants.Dirs.UTILS}/state",
f"$/{constants.Dirs.UTILS}/components",
}


def bundle_library(component: Union["Component", str]):
Expand Down
21 changes: 0 additions & 21 deletions reflex/components/markdown/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,27 +192,6 @@ def create(cls, *children, **props) -> Component:
**props,
)

def _get_all_custom_components(
self, seen: set[str] | None = None
) -> set[CustomComponent]:
"""Get all the custom components used by the component.

Args:
seen: The tags of the components that have already been seen.

Returns:
The set of custom components.
"""
custom_components = super()._get_all_custom_components(seen=seen)

# Get the custom components for each tag.
for component in self.component_map.values():
custom_components |= component(_MOCK_ARG)._get_all_custom_components(
seen=seen
)

return custom_components

def add_imports(self) -> ImportDict | list[ImportDict]:
"""Add imports for the markdown component.

Expand Down
14 changes: 9 additions & 5 deletions tests/units/components/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import reflex as rx
from reflex.base import Base
from reflex.compiler.compiler import compile_components
from reflex.compiler.utils import compile_custom_component
from reflex.components.base.bare import Bare
from reflex.components.base.fragment import Fragment
from reflex.components.component import (
CUSTOM_COMPONENTS,
Component,
CustomComponent,
StatefulComponent,
Expand Down Expand Up @@ -877,7 +878,7 @@ def test_create_custom_component(my_component):
component = rx.memo(my_component)(prop1="test", prop2=1)
assert component.tag == "MyComponent"
assert component.get_props() == {"prop1", "prop2"}
assert component._get_all_custom_components() == {component}
assert component.tag in CUSTOM_COMPONENTS


def test_custom_component_hash(my_component):
Expand Down Expand Up @@ -1801,10 +1802,13 @@ def outer(c: Component):

# Inner is not imported directly, but it is imported by the custom component.
assert "inner" not in custom_comp._get_all_imports()
assert "outer" not in custom_comp._get_all_imports()

# The imports are only resolved during compilation.
_, _, imports_inner = compile_components(custom_comp._get_all_custom_components())
custom_comp.get_component(custom_comp)
_, imports_inner = compile_custom_component(custom_comp)
assert "inner" in imports_inner
assert "outer" not in imports_inner

outer_comp = outer(c=wrapper())

Expand All @@ -1813,8 +1817,8 @@ def outer(c: Component):
assert "other" not in outer_comp._get_all_imports()

# The imports are only resolved during compilation.
_, _, imports_outer = compile_components(outer_comp._get_all_custom_components())
assert "inner" in imports_outer
_, imports_outer = compile_custom_component(outer_comp)
assert "inner" not in imports_outer
assert "other" in imports_outer


Expand Down
Loading