Skip to content

Commit 1e78f2c

Browse files
feat: add experimental memo decorator for JS-level component and function memoization (#6192)
* feat: add experimental memo decorator for JS-level component and function memoization Introduce rx.experimental.memo (rx._x.memo) that allows memoizing components and plain functions at the JavaScript level. Supports component memos with typed props (including children and rest props via RestProp), and function memos that emit raw JS. Updates the compiler pipeline to handle both memo kinds alongside existing CustomComponent memos, and refactors signature rendering to use DestructuredArg. * fix: prevent memo name collisions and compile-time mutation of stored components Add registry helpers that detect duplicate exported names across memo kinds and raise on collision. Deepcopy the component before applying styles during compilation so the stored definition stays clean. Simplify the function wrappers .call to alias the wrapper itself. * test: clear old memos when testing. * test: cleanup * pyi: update hashes * fix: camelCase rest-prop keys in memo function bindings and clean up memo internals Convert remaining_props keys to camelCase in _bind_function_runtime_args so rest props (e.g. class_name → className) match the component memo behavior. Also make MemoParam kw_only, return a tuple from get_props instead of a dict, and remove unnecessary monkeypatch boilerplate from the integration test fixture. * refactor: replace memo wrapper closures with proper callable classes Replace _create_function_wrapper and _create_component_wrapper closures with _ExperimentalMemoFunctionWrapper and _ExperimentalMemoComponentWrapper classes, eliminating object.__setattr__ hacks for call/partial/_as_var in favor of real methods. * updated hashes * fix: accept Var[Component] return from component-returning memos Extract _normalize_component_return to wrap Var[Component] values in Bare.create, allowing memos that return rx.cond or other component-typed vars to be registered as component memos. Add a cond overload for (Any, Var[Component], Var[Component]) -> Component. * refactor: create per-memo component subclasses with tag set at class level Replace instance-level self.tag assignment with cached dynamically created ExperimentalMemoComponent subclasses via _get_experimental_memo_component_class, so the tag is a class-level attribute rather than set in _post_init. --------- Co-authored-by: Masen Furer <m_github@0x26.net>
1 parent 931ac2c commit 1e78f2c

17 files changed

Lines changed: 1859 additions & 19 deletions

File tree

pyi_hashes.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"reflex/__init__.pyi": "a0266c47111e9af7f340186013c7a31e",
2+
"reflex/__init__.pyi": "9823934f0e3fca36228004a6fbb1d8df",
33
"reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb",
44
"reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a",
55
"reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1",
@@ -118,5 +118,6 @@
118118
"reflex/components/recharts/general.pyi": "9abf71810a5405fd45b13804c0a7fd1a",
119119
"reflex/components/recharts/polar.pyi": "ea4743e8903365ba95bc4b653c47cc4a",
120120
"reflex/components/recharts/recharts.pyi": "b3d93d085d51053bbb8f65326f34a299",
121-
"reflex/components/sonner/toast.pyi": "636050fcc919f8ab0903c30dceaa18f1"
121+
"reflex/components/sonner/toast.pyi": "636050fcc919f8ab0903c30dceaa18f1",
122+
"reflex/experimental/memo.pyi": "78b1968972194785f72eab32476bc61d"
122123
}

reflex/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@
347347
"utils.imports": ["ImportDict", "ImportVar"],
348348
"utils.misc": ["run_in_thread"],
349349
"utils.serializers": ["serializer"],
350-
"vars": ["Var", "field", "Field"],
350+
"vars": ["Var", "field", "Field", "RestProp"],
351351
}
352352

353353
_SUBMODULES: set[str] = {

reflex/app.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
get_hydrate_event,
8383
noop,
8484
)
85+
from reflex.experimental.memo import EXPERIMENTAL_MEMOS
8586
from reflex.istate.manager import StateModificationContext
8687
from reflex.istate.proxy import StateProxy
8788
from reflex.page import DECORATED_PAGES
@@ -1280,7 +1281,10 @@ def memoized_toast_provider():
12801281
memo_components_output,
12811282
memo_components_result,
12821283
memo_components_imports,
1283-
) = compiler.compile_memo_components(dict.fromkeys(CUSTOM_COMPONENTS.values()))
1284+
) = compiler.compile_memo_components(
1285+
dict.fromkeys(CUSTOM_COMPONENTS.values()),
1286+
tuple(EXPERIMENTAL_MEMOS.values()),
1287+
)
12841288
compile_results.append((memo_components_output, memo_components_result))
12851289
all_imports.update(memo_components_imports)
12861290
progress.advance(task)

reflex/compiler/compiler.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
from reflex.constants.compiler import PageNames, ResetStylesheet
2323
from reflex.constants.state import FIELD_MARKER
2424
from reflex.environment import environment
25+
from reflex.experimental.memo import (
26+
ExperimentalMemoComponentDefinition,
27+
ExperimentalMemoDefinition,
28+
ExperimentalMemoFunctionDefinition,
29+
)
2530
from reflex.state import BaseState
2631
from reflex.style import SYSTEM_COLOR_MODE
2732
from reflex.utils import console, path_ops
@@ -339,28 +344,46 @@ def _compile_component(component: Component | StatefulComponent) -> str:
339344

340345
def _compile_memo_components(
341346
components: Iterable[CustomComponent],
347+
experimental_memos: Iterable[ExperimentalMemoDefinition] = (),
342348
) -> tuple[str, dict[str, list[ImportVar]]]:
343349
"""Compile the components.
344350
345351
Args:
346352
components: The components to compile.
353+
experimental_memos: The experimental memos to compile.
347354
348355
Returns:
349356
The compiled components.
350357
"""
351-
imports = {
352-
"react": [ImportVar(tag="memo")],
353-
f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
354-
}
358+
imports: dict[str, list[ImportVar]] = {}
355359
component_renders = []
360+
function_renders = []
356361

357362
# Compile each component.
358363
for component in components:
359364
component_render, component_imports = utils.compile_custom_component(component)
360365
component_renders.append(component_render)
361366
imports = utils.merge_imports(imports, component_imports)
362367

363-
_apply_common_imports(imports)
368+
for memo in experimental_memos:
369+
if isinstance(memo, ExperimentalMemoComponentDefinition):
370+
memo_render, memo_imports = utils.compile_experimental_component_memo(memo)
371+
component_renders.append(memo_render)
372+
imports = utils.merge_imports(imports, memo_imports)
373+
elif isinstance(memo, ExperimentalMemoFunctionDefinition):
374+
memo_render, memo_imports = utils.compile_experimental_function_memo(memo)
375+
function_renders.append(memo_render)
376+
imports = utils.merge_imports(imports, memo_imports)
377+
378+
if component_renders:
379+
imports = utils.merge_imports(
380+
{
381+
"react": [ImportVar(tag="memo")],
382+
f"$/{constants.Dirs.STATE_PATH}": [ImportVar(tag="isTrue")],
383+
},
384+
imports,
385+
)
386+
_apply_common_imports(imports)
364387

365388
dynamic_imports = {
366389
comp_import: None
@@ -380,6 +403,7 @@ def _compile_memo_components(
380403
templates.memo_components_template(
381404
imports=utils.compile_imports(imports),
382405
components=component_renders,
406+
functions=function_renders,
383407
dynamic_imports=sorted(dynamic_imports),
384408
custom_codes=custom_codes,
385409
),
@@ -573,11 +597,13 @@ def compile_page(path: str, component: BaseComponent) -> tuple[str, str]:
573597

574598
def compile_memo_components(
575599
components: Iterable[CustomComponent],
600+
experimental_memos: Iterable[ExperimentalMemoDefinition] = (),
576601
) -> tuple[str, str, dict[str, list[ImportVar]]]:
577602
"""Compile the custom components.
578603
579604
Args:
580605
components: The custom components to compile.
606+
experimental_memos: The experimental memos to compile.
581607
582608
Returns:
583609
The path and code of the compiled components.
@@ -586,7 +612,7 @@ def compile_memo_components(
586612
output_path = utils.get_components_path()
587613

588614
# Compile the components.
589-
code, imports = _compile_memo_components(components)
615+
code, imports = _compile_memo_components(components, experimental_memos)
590616
return output_path, code, imports
591617

592618

reflex/compiler/templates.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from reflex import constants
1010
from reflex.constants import Hooks
11-
from reflex.constants.state import CAMEL_CASE_MEMO_MARKER
1211
from reflex.utils.format import format_state_name, json_dumps
1312
from reflex.vars.base import VarData
1413

@@ -661,6 +660,7 @@ def stateful_components_template(imports: list[_ImportDict], memoized_code: str)
661660
def memo_components_template(
662661
imports: list[_ImportDict],
663662
components: list[dict[str, Any]],
663+
functions: list[dict[str, Any]],
664664
dynamic_imports: Iterable[str],
665665
custom_codes: Iterable[str],
666666
) -> str:
@@ -669,6 +669,7 @@ def memo_components_template(
669669
Args:
670670
imports: List of import statements.
671671
components: List of component definitions.
672+
functions: List of function definitions.
672673
dynamic_imports: List of dynamic import statements.
673674
custom_codes: List of custom code snippets.
674675
@@ -682,21 +683,29 @@ def memo_components_template(
682683
components_code = ""
683684
for component in components:
684685
components_code += f"""
685-
export const {component["name"]} = memo(({{ {",".join([f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in component.get("props", [])])} }}) => {{
686+
export const {component["name"]} = memo(({component["signature"]}) => {{
686687
{_render_hooks(component.get("hooks", {}))}
687688
return(
688689
{_RenderUtils.render(component["render"])}
689690
)
690691
}});
691692
"""
692693

694+
functions_code = ""
695+
for function in functions:
696+
functions_code += (
697+
f"\nexport const {function['name']} = {function['function']};\n"
698+
)
699+
693700
return f"""
694701
{imports_str}
695702
696703
{dynamic_imports_str}
697704
698705
{custom_code_str}
699706
707+
{functions_code}
708+
700709
{components_code}"""
701710

702711

reflex/compiler/utils.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import concurrent.futures
7+
import copy
78
import operator
89
import traceback
910
from collections.abc import Mapping, Sequence
@@ -20,14 +21,19 @@
2021
from reflex.components.el.elements.metadata import Head, Link, Meta, Title
2122
from reflex.components.el.elements.other import Html
2223
from reflex.components.el.elements.sectioning import Body
23-
from reflex.constants.state import FIELD_MARKER
24+
from reflex.constants.state import CAMEL_CASE_MEMO_MARKER, FIELD_MARKER
25+
from reflex.experimental.memo import (
26+
ExperimentalMemoComponentDefinition,
27+
ExperimentalMemoFunctionDefinition,
28+
)
2429
from reflex.istate.storage import Cookie, LocalStorage, SessionStorage
2530
from reflex.state import BaseState, _resolve_delta
2631
from reflex.style import Style
2732
from reflex.utils import format, imports, path_ops
2833
from reflex.utils.imports import ImportVar, ParsedImportDict
2934
from reflex.utils.prerequisites import get_web_dir
3035
from reflex.vars.base import Field, Var, VarData
36+
from reflex.vars.function import DestructuredArg
3137

3238
# To re-export this function.
3339
merge_imports = imports.merge_imports
@@ -344,6 +350,9 @@ def compile_custom_component(
344350
{
345351
"name": component.tag,
346352
"props": props,
353+
"signature": DestructuredArg(
354+
fields=tuple(f"{prop}:{prop}{CAMEL_CASE_MEMO_MARKER}" for prop in props)
355+
).to_javascript(),
347356
"render": render.render(),
348357
"hooks": render._get_all_hooks(),
349358
"custom_code": render._get_all_custom_code(),
@@ -353,6 +362,104 @@ def compile_custom_component(
353362
)
354363

355364

365+
def _apply_component_style_for_compile(component: Component) -> Component:
366+
"""Apply the app style to a compiled component tree.
367+
368+
Args:
369+
component: The component tree.
370+
371+
Returns:
372+
The styled component tree.
373+
"""
374+
try:
375+
from reflex.utils.prerequisites import get_and_validate_app
376+
377+
style = get_and_validate_app().app.style
378+
except Exception:
379+
style = {}
380+
381+
component._add_style_recursive(style)
382+
return component
383+
384+
385+
def compile_experimental_component_memo(
386+
definition: ExperimentalMemoComponentDefinition,
387+
) -> tuple[dict, ParsedImportDict]:
388+
"""Compile an experimental memo component.
389+
390+
Args:
391+
definition: The component memo definition.
392+
393+
Returns:
394+
A tuple of the compiled component definition and its imports.
395+
"""
396+
render = _apply_component_style_for_compile(copy.deepcopy(definition.component))
397+
398+
imports: ParsedImportDict = {
399+
lib: fields
400+
for lib, fields in render._get_all_imports().items()
401+
if lib != f"$/{constants.Dirs.COMPONENTS_PATH}"
402+
}
403+
404+
imports.setdefault("@emotion/react", []).append(ImportVar("jsx"))
405+
406+
signature_fields = [
407+
f"{param.js_prop_name}:{param.placeholder_name}"
408+
for param in definition.params
409+
if not param.is_children and not param.is_rest
410+
]
411+
412+
if any(param.is_children for param in definition.params):
413+
signature_fields.insert(0, "children")
414+
415+
rest_param = next((param for param in definition.params if param.is_rest), None)
416+
417+
return (
418+
{
419+
"kind": "component",
420+
"name": definition.export_name,
421+
"signature": DestructuredArg(
422+
fields=tuple(signature_fields),
423+
rest=rest_param.placeholder_name if rest_param is not None else None,
424+
).to_javascript(),
425+
"render": render.render(),
426+
"hooks": render._get_all_hooks(),
427+
"custom_code": render._get_all_custom_code(),
428+
"dynamic_imports": render._get_all_dynamic_imports(),
429+
},
430+
imports,
431+
)
432+
433+
434+
def compile_experimental_function_memo(
435+
definition: ExperimentalMemoFunctionDefinition,
436+
) -> tuple[dict, ParsedImportDict]:
437+
"""Compile an experimental memo function.
438+
439+
Args:
440+
definition: The function memo definition.
441+
442+
Returns:
443+
A tuple of the compiled function definition and its imports.
444+
"""
445+
imports: ParsedImportDict = {}
446+
if var_data := definition.function._get_all_var_data():
447+
imports = {
448+
lib: list(fields)
449+
for lib, fields in dict(var_data.imports).items()
450+
if lib != f"$/{constants.Dirs.COMPONENTS_PATH}"
451+
}
452+
453+
return (
454+
{
455+
"kind": "function",
456+
"name": definition.python_name,
457+
"function": str(definition.function),
458+
},
459+
imports,
460+
)
461+
462+
356463
def create_document_root(
357464
head_components: Sequence[Component] | None = None,
358465
html_lang: str | None = None,

reflex/components/core/cond.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ def cond(condition: Any, c1: Component, c2: Any, /) -> Component: ... # pyright
9898
def cond(condition: Any, c1: Component, /) -> Component: ...
9999

100100

101+
@overload
102+
def cond(condition: Any, c1: Var[Component], c2: Var[Component], /) -> Component: ... # pyright: ignore [reportOverlappingOverload]
103+
104+
101105
@overload
102106
def cond(condition: Any, c1: Any, c2: Component, /) -> Component: ... # pyright: ignore [reportOverlappingOverload]
103107

reflex/experimental/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from . import hooks as hooks
1010
from .client_state import ClientStateVar as ClientStateVar
11+
from .memo import memo as memo
1112

1213

1314
class ExperimentalNamespace(SimpleNamespace):
@@ -58,4 +59,5 @@ def register_component_warning(component_name: str):
5859
client_state=ClientStateVar.create,
5960
hooks=hooks,
6061
code_block=code_block,
62+
memo=memo,
6163
)

0 commit comments

Comments
 (0)