Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 156 additions & 97 deletions packages/reflex-base/src/reflex_base/components/component.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -154,27 +154,27 @@ def fix_event_triggers_for_memo(
"""Return a component whose event triggers reference memoized ``useCallback``s.

Replaces each (non-lifecycle) event-trigger value with a ``Var`` naming a
memoized ``useCallback`` wrapper. The original is never mutated — a
page-local clone is taken via ``page_context.own`` on first write.
memoized ``useCallback`` wrapper. The original is never mutated — a frozen
copy with the rewritten triggers is returned via ``copy_with``.

Args:
component: The component whose event triggers to memoize.
page_context: The active page context, used to obtain a page-local
clone before rewriting ``event_triggers``.
page_context: The active page context (unused; retained for API
compatibility with downstream callers).

Returns:
Either ``component`` (when nothing needed rewriting) or a page-local
clone with the rewritten ``event_triggers``.
Either ``component`` (when nothing needed rewriting) or a new frozen
copy with the rewritten ``event_triggers``.
"""
memo_event_triggers = tuple(get_memoized_event_triggers(component).items())
if not memo_event_triggers:
return component
owned = page_context.own(component)
owned.event_triggers = {
**component.event_triggers,
**dict(memo_event_triggers),
}
return owned
return component.copy_with(
event_triggers={
**component.event_triggers,
**dict(memo_event_triggers),
}
)


def is_snapshot_boundary(component: Component) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,6 @@ def render_component(self) -> Component:

# Set the component key.
if component.key is None:
component.key = index
component = component.copy_with(key=index)

return component
49 changes: 6 additions & 43 deletions packages/reflex-base/src/reflex_base/plugins/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from __future__ import annotations

import copy
import dataclasses
import inspect
from collections.abc import Callable, Sequence
from contextvars import ContextVar, Token
from types import TracebackType
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, TypeVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeAlias, cast

from typing_extensions import Self

Expand All @@ -32,9 +31,6 @@
)


_BaseComponentT = TypeVar("_BaseComponentT", bound=BaseComponent)


class PageDefinition(Protocol):
"""Protocol for page-like objects compiled by :class:`CompileContext`."""

Expand Down Expand Up @@ -374,8 +370,7 @@ def visit(
updated_children = list(children[:index])
updated_children.append(compiled_child)
if updated_children is not None:
current_comp = page_context.own(current_comp)
current_comp.children = updated_children
current_comp = current_comp.copy_with(children=tuple(updated_children))

if isinstance(current_comp, Component):
for prop_component in current_comp._get_components_in_props():
Expand Down Expand Up @@ -437,8 +432,7 @@ def visit(
updated_children = list(children[:index])
updated_children.append(compiled_child)
if updated_children is not None:
current_comp = page_context.own(current_comp)
current_comp.children = updated_children
current_comp = current_comp.copy_with(children=tuple(updated_children))

if isinstance(current_comp, Component):
for prop_component in current_comp._get_components_in_props():
Expand Down Expand Up @@ -549,8 +543,9 @@ def visit(
if len(compiled_children) != len(current) or any(
a is not b for a, b in zip(compiled_children, current, strict=True)
):
compiled_component = page_context.own(compiled_component)
compiled_component.children = list(compiled_children)
compiled_component = compiled_component.copy_with(
children=tuple(compiled_children)
)
return compiled_component

return visit(
Expand Down Expand Up @@ -695,38 +690,6 @@ class PageContext(BaseContext):
# the matching ``leave_component``. Non-empty iff we are inside such a
# subtree.
memoize_suppressor_stack: list[int] = dataclasses.field(default_factory=list)
# Maps both the user-owned original's ``id()`` and the clone's ``id()`` to
# the page-local clone. Lets the walker and plugins rebind children, style,
# or event_triggers on a page-local copy without mutating a user-owned
# instance that may be referenced from another route.
_owned: dict[int, BaseComponent] = dataclasses.field(default_factory=dict)
# Strong references to originals keyed by ``id()`` above. Without these,
# an original that is only reachable through ``_owned``'s int key can be
# garbage collected, and Python may recycle its ``id()`` for a fresh
# component, causing ``own()`` to hand back the wrong clone.
_owned_refs: list[BaseComponent] = dataclasses.field(default_factory=list)

def own(self, comp: _BaseComponentT) -> _BaseComponentT:
"""Return a page-local copy of ``comp``, cloning on first encounter.

Repeated calls with the same original return the same clone, so
mutations from several plugins accumulate on one instance.

Args:
comp: The component the caller is about to mutate.

Returns:
A component the caller may freely mutate without touching any
user-owned instance.
"""
existing = self._owned.get(id(comp))
if existing is not None:
return cast("_BaseComponentT", existing)
new = copy.copy(comp)
self._owned[id(comp)] = new
self._owned[id(new)] = new
self._owned_refs.append(comp)
return new

def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict:
"""Return the imports accumulated for this page.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import dataclasses
from collections.abc import Iterator, Sequence
from typing import Any

Expand All @@ -13,7 +14,7 @@
from reflex_base.utils.decorator import once
from reflex_base.utils.imports import ParsedImportDict
from reflex_base.vars import BooleanVar, ObjectVar, Var
from reflex_base.vars.base import GLOBAL_CACHE, VarData
from reflex_base.vars.base import VarData
from reflex_base.vars.sequence import LiteralStringVar


Expand Down Expand Up @@ -215,22 +216,44 @@ def _add_style_recursive(
theme: The theme to add.

Returns:
The component with the style added.
A component with the style added; ``self`` if nothing changed.
"""
new_self = super()._add_style_recursive(style, theme)

are_components_touched = False

if isinstance(self.contents, Var):
for component in _components_from_var(self.contents):
if isinstance(component, Component):
component._add_style_recursive(style, theme)
are_components_touched = True

if are_components_touched:
GLOBAL_CACHE.clear()

return new_self
if not isinstance(self.contents, Var):
return new_self
var_data = self.contents._var_data
if not var_data or not var_data.components:
return new_self

rebuilt: list | None = None
for i, embedded in enumerate(var_data.components):
if not isinstance(embedded, Component):
continue
updated = embedded._add_style_recursive(style, theme)
if updated is embedded:
continue
if rebuilt is None:
rebuilt = list(var_data.components)
rebuilt[i] = updated

if rebuilt is None:
return new_self

new_var_data = VarData(
state=var_data.state,
field_name=var_data.field_name,
imports=var_data.old_school_imports(),
hooks=dict.fromkeys(var_data.hooks),
deps=list(var_data.deps),
position=var_data.position,
components=tuple(rebuilt),
)
new_contents = dataclasses.replace(
self.contents,
_var_data=new_var_data,
)
return new_self.copy_with(contents=new_contents)

def _get_vars(
self, include_children: bool = False, ignore_ids: set[int] | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def create(cls, **props) -> Component:
)

info_message = el.div(
warning_icon,
el.span(
"If you are the owner of this app, visit ",
el.a(
Expand Down Expand Up @@ -390,8 +391,6 @@ def create(cls, **props) -> Component:
background_color=color("amber", 3),
padding="0.625rem",
)
# Prepend warning icon into info_message children
info_message.children.insert(0, warning_icon)

resume_button = el.a(
el.button(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from reflex_base.components.component import Component, field
from reflex_base.constants import EventTriggers
from reflex_base.event import EventHandler, no_args_event_spec
from reflex_base.style import Style
from reflex_base.utils import format
from reflex_base.vars import VarData
from reflex_base.vars.base import Var
Expand Down Expand Up @@ -94,9 +95,13 @@ def create(cls, *children: Component, **props: Any) -> Component:
for p in cls.get_props()
if getattr(child, p, None) is not None
}
props[EventTriggers.ON_CHANGE] = child.event_triggers.pop(
EventTriggers.ON_CHANGE
)
on_change = child.event_triggers[EventTriggers.ON_CHANGE]
child_event_triggers_minus_on_change = {
k: v
for k, v in child.event_triggers.items()
if k != EventTriggers.ON_CHANGE
}
props[EventTriggers.ON_CHANGE] = on_change
props = {**props_from_child, **props}

# Carry all other child props directly via custom_attrs
Expand All @@ -114,9 +119,11 @@ def create(cls, *children: Component, **props: Any) -> Component:
debounce_input_prop_names = {
format.to_camel_case(prop) for prop in cls.get_props()
}
for colliding_key in [k for k in child.style if k in debounce_input_prop_names]:
child.style.pop(colliding_key)
props.setdefault("style", {}).update(child.style)
cleaned_child_style = Style({
k: v for k, v in child.style.items() if k not in debounce_input_prop_names
})
cleaned_child = child.copy_with(style=cleaned_child_style)
props.setdefault("style", {}).update(cleaned_child_style)
if child.class_name is not None:
props["class_name"] = f"{props.get('class_name', '')} {child.class_name}"
for prop_name in ("key", "special_props"):
Expand All @@ -142,15 +149,19 @@ def create(cls, *children: Component, **props: Any) -> Component:
)

component = super().create(**props)
component._get_style = child._get_style
component.event_triggers.update(child.event_triggers)
component.children = child.children
component._rename_props = child._rename_props # pyright: ignore[reportAttributeAccessIssue]
outer_get_all_custom_code = component._get_all_custom_code
component._get_all_custom_code = lambda: (
outer_get_all_custom_code() | (child._get_all_custom_code())
return component.copy_with(
children=child.children,
event_triggers={
**component.event_triggers,
**child_event_triggers_minus_on_change,
},
_get_style=cleaned_child._get_style,
_rename_props=cleaned_child._rename_props,
_get_all_custom_code=lambda: (
outer_get_all_custom_code() | child._get_all_custom_code()
),
)
return component

def _render(self):
return super()._render().remove_props("ref")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ def create(
)
try:
# Keep a ref to a rendered component to determine correct imports/hooks/styles.
component.children = [component._render().render_component()]
return component.copy_with(
children=(component._render().render_component(),)
)
except UntypedVarError as e:
raise UntypedVarError(
iterable,
"foreach",
"https://reflex.dev/docs/library/dynamic-rendering/foreach/",
).with_traceback(e.__traceback__) from None
return component

def _render(self) -> IterTag:
props = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,28 +399,30 @@ def create(cls, *children, **props) -> Component:
)

# The file input to use.
upload = Input.create(type="file")
upload.special_props = [
Var(
_js_expr=f"{input_props_unique_name}()",
_var_type=None,
_var_data=var_data,
)
]
upload = Input.create(type="file").copy_with(
special_props=[
Var(
_js_expr=f"{input_props_unique_name}()",
_var_type=None,
_var_data=var_data,
)
]
)

# The dropzone to use.
zone = Div.create(
upload,
*children,
**{k: v for k, v in props.items() if k not in supported_props},
).copy_with(
special_props=[
Var(
_js_expr=f"{root_props_unique_name}()",
_var_type=None,
_var_data=var_data,
)
]
)
zone.special_props = [
Var(
_js_expr=f"{root_props_unique_name}()",
_var_type=None,
_var_data=var_data,
)
]

return super().create(
zone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,17 @@ def create(cls, **props) -> WindowEventListener:

real_component = cast("WindowEventListener", super().create(**props))
memo_event_triggers = get_memoized_event_triggers(real_component)
if memo_event_triggers:
real_component.event_triggers = {
**real_component.event_triggers,
**memo_event_triggers,
}
return real_component
if not memo_event_triggers:
return real_component
return cast(
"WindowEventListener",
real_component.copy_with(
event_triggers={
**real_component.event_triggers,
**memo_event_triggers,
}
),
)

def _exclude_props(self) -> list[str]:
"""Exclude event handler props from being passed to Fragment.
Expand Down
Loading
Loading