diff --git a/packages/reflex-core/src/reflex_core/environment.py b/packages/reflex-core/src/reflex_core/environment.py index a747cd21ba1..f005dd41ff3 100644 --- a/packages/reflex-core/src/reflex_core/environment.py +++ b/packages/reflex-core/src/reflex_core/environment.py @@ -2,14 +2,11 @@ from __future__ import annotations -import concurrent.futures import dataclasses import enum import importlib -import multiprocessing import os -import platform -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import lru_cache from pathlib import Path from typing import ( @@ -529,97 +526,6 @@ class PerformanceMode(enum.Enum): OFF = "off" -class ExecutorType(enum.Enum): - """Executor for compiling the frontend.""" - - THREAD = "thread" - PROCESS = "process" - MAIN_THREAD = "main_thread" - - @classmethod - def get_executor_from_environment(cls): - """Get the executor based on the environment variables. - - Returns: - The executor. - """ - from reflex_core.utils import console - - executor_type = environment.REFLEX_COMPILE_EXECUTOR.get() - - reflex_compile_processes = environment.REFLEX_COMPILE_PROCESSES.get() - reflex_compile_threads = environment.REFLEX_COMPILE_THREADS.get() - # By default, use the main thread. Unless the user has specified a different executor. - # Using a process pool is much faster, but not supported on all platforms. It's gated behind a flag. - if executor_type is None: - if ( - platform.system() not in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - console.warn("Multiprocessing is only supported on Linux and MacOS.") - - if ( - platform.system() in ("Linux", "Darwin") - and reflex_compile_processes is not None - ): - if reflex_compile_processes == 0: - console.warn( - "Number of processes must be greater than 0. If you want to use the default number of processes, set REFLEX_COMPILE_EXECUTOR to 'process'. Defaulting to None." - ) - reflex_compile_processes = None - elif reflex_compile_processes < 0: - console.warn( - "Number of processes must be greater than 0. Defaulting to None." - ) - reflex_compile_processes = None - executor_type = ExecutorType.PROCESS - elif reflex_compile_threads is not None: - if reflex_compile_threads == 0: - console.warn( - "Number of threads must be greater than 0. If you want to use the default number of threads, set REFLEX_COMPILE_EXECUTOR to 'thread'. Defaulting to None." - ) - reflex_compile_threads = None - elif reflex_compile_threads < 0: - console.warn( - "Number of threads must be greater than 0. Defaulting to None." - ) - reflex_compile_threads = None - executor_type = ExecutorType.THREAD - else: - executor_type = ExecutorType.MAIN_THREAD - - match executor_type: - case ExecutorType.PROCESS: - executor = concurrent.futures.ProcessPoolExecutor( - max_workers=reflex_compile_processes, - mp_context=multiprocessing.get_context("fork"), - ) - case ExecutorType.THREAD: - executor = concurrent.futures.ThreadPoolExecutor( - max_workers=reflex_compile_threads - ) - case ExecutorType.MAIN_THREAD: - FUTURE_RESULT_TYPE = TypeVar("FUTURE_RESULT_TYPE") - - class MainThreadExecutor: - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - def submit( - self, fn: Callable[..., FUTURE_RESULT_TYPE], *args, **kwargs - ) -> concurrent.futures.Future[FUTURE_RESULT_TYPE]: - future_job = concurrent.futures.Future() - future_job.set_result(fn(*args, **kwargs)) - return future_job - - executor = MainThreadExecutor() - - return executor - - class EnvironmentVariables: """Environment variables class to instantiate environment variables.""" @@ -660,14 +566,6 @@ class EnvironmentVariables: Path(constants.Dirs.UPLOADED_FILES) ) - REFLEX_COMPILE_EXECUTOR: EnvVar[ExecutorType | None] = env_var(None) - - # Whether to use separate processes to compile the frontend and how many. If not set, defaults to thread executor. - REFLEX_COMPILE_PROCESSES: EnvVar[int | None] = env_var(None) - - # Whether to use separate threads to compile the frontend and how many. Defaults to `min(32, os.cpu_count() + 4)`. - REFLEX_COMPILE_THREADS: EnvVar[int | None] = env_var(None) - # The directory to store reflex dependencies. REFLEX_DIR: EnvVar[Path] = env_var(constants.Reflex.DIR) diff --git a/pyproject.toml b/pyproject.toml index 4bfb76a3608..2ac4db66942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,7 +230,7 @@ asyncio_mode = "auto" [tool.codespell] skip = "docs/*,*.html,examples/*, *.pyi, poetry.lock, uv.lock" -ignore-words-list = "te, TreeE, selectin" +ignore-words-list = "te, TreeE, selectin, asend" [tool.coverage.run] diff --git a/reflex/app.py b/reflex/app.py index 375e869d288..cd57a8d00b4 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -3,14 +3,11 @@ from __future__ import annotations import asyncio -import concurrent.futures import contextlib -import copy import dataclasses import functools import inspect import json -import operator import sys import time import traceback @@ -20,20 +17,15 @@ AsyncIterator, Callable, Coroutine, + Iterable, Mapping, Sequence, ) -from datetime import datetime -from itertools import chain -from pathlib import Path -from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ParamSpec +from typing import TYPE_CHECKING, Any -from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.error_boundary import ErrorBoundary from reflex_components_core.base.fragment import Fragment -from reflex_components_core.base.strict_mode import StrictMode from reflex_components_core.core.banner import ( backend_disabled, connection_pulser, @@ -44,14 +36,9 @@ from reflex_components_radix import themes from reflex_components_sonner.toast import toast from reflex_core import constants -from reflex_core.components.component import ( - CUSTOM_COMPONENTS, - Component, - ComponentStyle, - evaluate_style_namespaces, -) +from reflex_core.components.component import Component, ComponentStyle from reflex_core.config import get_config -from reflex_core.environment import ExecutorType, environment +from reflex_core.environment import environment from reflex_core.event import ( _EVENT_FIELDS, Event, @@ -64,7 +51,6 @@ from reflex_core.utils import console from reflex_core.utils.imports import ImportVar from reflex_core.utils.types import ASGIApp, Message, Receive, Scope, Send -from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer from starlette.applications import Starlette @@ -79,13 +65,7 @@ from reflex.admin import AdminDash from reflex.app_mixins import AppMixin, LifespanMixin, MiddlewareMixin from reflex.compiler import compiler -from reflex.compiler import utils as compiler_utils -from reflex.compiler.compiler import ( - ExecutorSafeFunctions, - compile_theme, - readable_name_from_component, -) -from reflex.experimental.memo import EXPERIMENTAL_MEMOS +from reflex.compiler.compiler import readable_name_from_component from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES @@ -101,18 +81,8 @@ StateUpdate, _split_substate_key, _substate_key, - all_base_state_classes, - code_uses_state_contexts, -) -from reflex.utils import ( - codespaces, - exceptions, - format, - frontend_skeleton, - js_runtimes, - path_ops, - prerequisites, ) +from reflex.utils import codespaces, exceptions, format, js_runtimes, prerequisites from reflex.utils.exec import ( get_compile_context, is_prod_mode, @@ -279,9 +249,6 @@ def merged_with(self, other: UnevaluatedPage) -> UnevaluatedPage: ) -P = ParamSpec("P") - - @dataclasses.dataclass() class App(MiddlewareMixin, LifespanMixin): """The main Reflex app that encapsulates the backend and frontend. @@ -841,19 +808,7 @@ def _compile_page(self, route: str, save_page: bool = True): route: The route of the page to compile. save_page: If True, the compiled page is saved to self._pages. """ - n_states_before = len(all_base_state_classes) - component = compiler.compile_unevaluated_page( - route, self._unevaluated_pages[route], self.style, self.theme - ) - - # Indicate that evaluating this page creates one or more state classes. - if len(all_base_state_classes) > n_states_before: - self._stateful_pages[route] = None - - # Add the page. - self._check_routes_conflict(route) - if save_page: - self._pages[route] = component + compiler._compile_page_from_app(self, route, save_page=save_page) @functools.cached_property def router(self) -> Callable[[str], str | None]: @@ -955,7 +910,7 @@ def _setup_admin_dash(self): admin.mount_to(self._api) - def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]): + def _get_frontend_packages(self, imports: Mapping[str, Iterable[ImportVar]]): """Gets the frontend packages to be installed and filters out the unnecessary ones. Args: @@ -990,24 +945,6 @@ def _get_frontend_packages(self, imports: dict[str, set[ImportVar]]): page_imports.update(filtered_frontend_packages) js_runtimes.install_frontend_packages(page_imports, get_config()) - def _app_root(self, app_wrappers: dict[tuple[int, str], Component]) -> Component: - for component in tuple(app_wrappers.values()): - app_wrappers.update(component._get_all_app_wrap_components()) - order = sorted(app_wrappers, key=operator.itemgetter(0), reverse=True) - root = copy.deepcopy(app_wrappers[order[0]]) - - def reducer(parent: Component, key: tuple[int, str]) -> Component: - child = copy.deepcopy(app_wrappers[key]) - parent.children.append(child) - return child - - functools.reduce( - lambda parent, key: reducer(parent, key), - order[1:], - root, - ) - return root - def _should_compile(self) -> bool: """Check if the app should be compiled. @@ -1123,391 +1060,13 @@ def _compile( ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined. FileNotFoundError: When a plugin requires a file that does not exist. """ - from reflex_core.utils.exceptions import ReflexRuntimeError - - self._apply_decorated_pages() - - self._pages = {} - - def get_compilation_time() -> str: - return str(datetime.now().time()).split(".")[0] - - should_compile = self._should_compile() - backend_dir = prerequisites.get_backend_dir() - if not dry_run and not should_compile and backend_dir.exists(): - stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES - if stateful_pages_marker.exists(): - with stateful_pages_marker.open("r") as f: - stateful_pages = json.load(f) - for route in stateful_pages: - console.debug(f"BE Evaluating stateful page: {route}") - self._compile_page(route, save_page=False) - self._add_optional_endpoints() - return - - # Render a default 404 page if the user didn't supply one - if constants.Page404.SLUG not in self._unevaluated_pages: - self.add_page(route=constants.Page404.SLUG) - - # Fix up the style. - self.style = evaluate_style_namespaces(self.style) - - # Add the app wrappers. - app_wrappers: dict[tuple[int, str], Component] = { - # Default app wrap component renders {children} - (0, "AppWrap"): AppWrap.create() - } - - if self.theme is not None: - # If a theme component was provided, wrap the app with it - app_wrappers[20, "Theme"] = self.theme - - # Get the env mode. - config = get_config() - - if config.react_strict_mode: - app_wrappers[200, "StrictMode"] = StrictMode.create() - - if not should_compile and not dry_run: - with console.timing("Evaluate Pages (Backend)"): - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - self._compile_page(route, save_page=should_compile) - - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - return - - # Create a progress bar. - progress = ( - Progress( - *Progress.get_default_columns()[:-1], - MofNCompleteColumn(), - TimeElapsedColumn(), - ) - if use_rich - else console.PoorProgress() - ) - - # try to be somewhat accurate - but still not 100% - adhoc_steps_without_executor = 7 - fixed_pages_within_executor = 4 - plugin_count = len(config.plugins) - progress.start() - task = progress.add_task( - f"[{get_compilation_time()}] Compiling:", - total=len(self._unevaluated_pages) - + ((len(self._unevaluated_pages) + len(self._pages)) * 3) - + fixed_pages_within_executor - + adhoc_steps_without_executor - + plugin_count, - ) - - with console.timing("Evaluate Pages (Frontend)"): - performance_metrics: list[tuple[str, float]] = [] - for route in self._unevaluated_pages: - console.debug(f"Evaluating page: {route}") - start = timer() - self._compile_page(route, save_page=should_compile) - end = timer() - performance_metrics.append((route, end - start)) - progress.advance(task) - console.debug( - "Slowest pages:\n" - + "\n".join( - f"{route}: {time * 1000:.1f}ms" - for route, time in sorted( - performance_metrics, key=operator.itemgetter(1), reverse=True - )[:10] - ) - ) - # Save the pages which created new states at eval time. - self._write_stateful_pages_marker() - - # Add the optional endpoints (_upload) - self._add_optional_endpoints() - - self._validate_var_dependencies() - self._setup_overlay_component() - - if config.show_built_with_reflex is None: - if ( - get_compile_context() == constants.CompileContext.DEPLOY - and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] - ): - config.show_built_with_reflex = False - else: - config.show_built_with_reflex = True - - if is_prod_mode() and config.show_built_with_reflex: - self._setup_sticky_badge() - - progress.advance(task) - - # Store the compile results. - compile_results: list[tuple[str, str]] = [] - - progress.advance(task) - - # Track imports found. - all_imports = {} - - if (toaster := self.toaster) is not None: - from reflex_core.components.component import memo - - @memo - def memoized_toast_provider(): - return toaster - - toast_provider = Fragment.create(memoized_toast_provider()) - - app_wrappers[44, "ToasterProvider"] = toast_provider - - # Add the app wraps to the app. - for key, app_wrap in chain( - self.app_wraps.items(), self.extra_app_wraps.items() - ): - # If the app wrap is a callable, generate the component - component = app_wrap(self._state is not None) - if component is not None: - app_wrappers[key] = component - - # Compile custom components. - ( - memo_components_output, - memo_components_result, - memo_components_imports, - ) = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), - ) - compile_results.append((memo_components_output, memo_components_result)) - all_imports.update(memo_components_imports) - progress.advance(task) - - with console.timing("Collect all imports and app wraps"): - # This has to happen before compiling stateful components as that - # prevents recursive functions from reaching all components. - for component in self._pages.values(): - # Add component._get_all_imports() to all_imports. - all_imports.update(component._get_all_imports()) - - # Add the app wrappers from this component. - app_wrappers.update(component._get_all_app_wrap_components()) - - progress.advance(task) - - # Perform auto-memoization of stateful components. - with console.timing("Auto-memoize StatefulComponents"): - ( - stateful_components_path, - stateful_components_code, - page_components, - ) = compiler.compile_stateful_components( - self._pages.values(), - progress_function=lambda task=task: progress.advance(task), - ) - progress.advance(task) - - # Catch "static" apps (that do not define a rx.State subclass) which are trying to access rx.State. - if code_uses_state_contexts(stateful_components_code) and self._state is None: - msg = ( - "To access rx.State in frontend components, at least one " - "subclass of rx.State must be defined in the app." - ) - raise ReflexRuntimeError(msg) - compile_results.append((stateful_components_path, stateful_components_code)) - - progress.advance(task) - - # Compile the root document before fork. - compile_results.append( - compiler.compile_document_root( - self.head_components, - html_lang=self.html_lang, - html_custom_attrs=( - {"suppressHydrationWarning": True, **self.html_custom_attrs} - if self.html_custom_attrs - else {"suppressHydrationWarning": True} - ), - ) - ) - - progress.advance(task) - - # Copy the assets. - assets_src = Path.cwd() / constants.Dirs.APP_ASSETS - if assets_src.is_dir() and not dry_run: - with console.timing("Copy assets"): - path_ops.update_directory_tree( - src=assets_src, - dest=( - Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC - ), - ) - - executor = ExecutorType.get_executor_from_environment() - - for route, component in zip(self._pages, page_components, strict=True): - ExecutorSafeFunctions.COMPONENTS[route] = component - - modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] - - with console.timing("Compile to Javascript"), executor as executor: - result_futures: list[ - concurrent.futures.Future[ - list[tuple[str, str]] | tuple[str, str] | None - ] - ] = [] - - def _submit_work( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - f.add_done_callback(lambda _: progress.advance(task)) - result_futures.append(f) - - # Compile the pre-compiled pages. - for route in self._pages: - _submit_work( - ExecutorSafeFunctions.compile_page, - route, - ) - - # Compile the root stylesheet with base styles. - _submit_work( - compiler.compile_root_stylesheet, self.stylesheets, self.reset_style - ) - - # Compile the theme. - _submit_work(compile_theme, self.style) - - def _submit_work_without_advancing( - fn: Callable[P, list[tuple[str, str]] | tuple[str, str] | None], - *args: P.args, - **kwargs: P.kwargs, - ): - f = executor.submit(fn, *args, **kwargs) - result_futures.append(f) - - for plugin in config.plugins: - plugin.pre_compile( - add_save_task=_submit_work_without_advancing, - add_modify_task=( - lambda *args, plugin=plugin: modify_files_tasks.append(( - plugin.__class__.__module__ + plugin.__class__.__name__, - *args, - )) - ), - unevaluated_pages=list(self._unevaluated_pages.values()), - ) - - # Wait for all compilation tasks to complete. - for future in concurrent.futures.as_completed(result_futures): - if (result := future.result()) is not None: - if isinstance(result, list): - compile_results.extend(result) - else: - compile_results.append(result) - - progress.advance(task, advance=len(config.plugins)) - - app_root = self._app_root(app_wrappers=app_wrappers) - - # Get imports from AppWrap components. - all_imports.update(app_root._get_all_imports()) - - progress.advance(task) - - # Compile the contexts. - compile_results.append( - compiler.compile_contexts(self._state, self.theme), - ) - if self.theme is not None: - # Fix #2992 by removing the top-level appearance prop - self.theme.appearance = None # pyright: ignore[reportAttributeAccessIssue] - progress.advance(task) - - # Compile the app root. - compile_results.append( - compiler.compile_app(app_root), - ) - progress.advance(task) - - progress.stop() - - if dry_run: - return - - # Install frontend packages. - with console.timing("Install Frontend Packages"): - self._get_frontend_packages(all_imports) - - # Setup the react-router.config.js - frontend_skeleton.update_react_router_config( + compiler.compile_app( + self, prerender_routes=prerender_routes, + dry_run=dry_run, + use_rich=use_rich, ) - if is_prod_mode(): - # Empty the .web pages directory. - compiler.purge_web_pages_dir() - else: - # In dev mode, delete removed pages and update existing pages. - keep_files = [Path(output_path) for output_path, _ in compile_results] - for p in Path( - prerequisites.get_web_dir() - / constants.Dirs.PAGES - / constants.Dirs.ROUTES - ).rglob("*"): - if p.is_file() and p not in keep_files: - # Remove pages that are no longer in the app. - p.unlink() - - output_mapping: dict[Path, str] = {} - for output_path, code in compile_results: - path = compiler_utils.resolve_path_of_web_dir(output_path) - if path in output_mapping: - console.warn( - f"Path {path} has two different outputs. The first one will be used." - ) - else: - output_mapping[path] = code - - for plugin in config.plugins: - for static_file_path, content in plugin.get_static_assets(): - path = compiler_utils.resolve_path_of_web_dir(static_file_path) - if path in output_mapping: - console.warn( - f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." - ) - else: - output_mapping[path] = ( - content.decode("utf-8") - if isinstance(content, bytes) - else content - ) - - for plugin_name, file_path, modify_fn in modify_files_tasks: - path = compiler_utils.resolve_path_of_web_dir(file_path) - file_content = output_mapping.get(path) - if file_content is None: - if path.exists(): - file_content = path.read_text() - else: - msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." - raise FileNotFoundError(msg) - output_mapping[path] = modify_fn(file_content) - - with console.timing("Write to Disk"): - for output_path, code in output_mapping.items(): - compiler_utils.write_file(output_path, code) - def _write_stateful_pages_marker(self): """Write list of routes that create dynamic states for the backend to use later.""" if self._state is not None: diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index d17a8baf012..d28c0a8d96b 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -2,20 +2,28 @@ from __future__ import annotations +import asyncio +import copy +import json +import operator import sys from collections.abc import Callable, Iterable, Sequence from inspect import getmodule from pathlib import Path from typing import TYPE_CHECKING, Any +from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.fragment import Fragment +from reflex_components_core.base.strict_mode import StrictMode from reflex_core import constants from reflex_core.components.component import ( + CUSTOM_COMPONENTS, BaseComponent, Component, ComponentStyle, CustomComponent, StatefulComponent, + evaluate_style_namespaces, ) from reflex_core.config import get_config from reflex_core.constants.compiler import PageNames, ResetStylesheet @@ -26,19 +34,40 @@ from reflex_core.utils.format import to_title_case from reflex_core.utils.imports import ImportVar, ParsedImportDict from reflex_core.vars.base import LiteralVar, Var +from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from reflex.compiler import templates, utils +from reflex.compiler.plugins import CompileContext, CompilerHooks, default_page_plugins from reflex.experimental.memo import ( + EXPERIMENTAL_MEMOS, ExperimentalMemoComponentDefinition, ExperimentalMemoDefinition, ExperimentalMemoFunctionDefinition, ) -from reflex.state import BaseState -from reflex.utils import console, path_ops -from reflex.utils.exec import is_prod_mode +from reflex.state import BaseState, code_uses_state_contexts +from reflex.utils import console, frontend_skeleton, path_ops, prerequisites +from reflex.utils.exec import get_compile_context, is_prod_mode from reflex.utils.prerequisites import get_web_dir +def _set_progress_total( + progress: Progress | console.PoorProgress, + task: Any, + total: int, +) -> None: + """Update a task total for either rich or fallback progress bars.""" + if isinstance(progress, Progress): + progress.update(task, total=total) + return + + if task not in progress.tasks: + return + + previous_total = progress.tasks[task]["total"] + progress.tasks[task]["total"] = total + progress.total += total - previous_total + + def _apply_common_imports( imports: dict[str, list[ImportVar]], ): @@ -521,7 +550,7 @@ def compile_document_root( return output_path, code -def compile_app(app_root: Component) -> tuple[str, str]: +def compile_app_root(app_root: Component) -> tuple[str, str]: """Compile the app root. Args: @@ -660,8 +689,388 @@ def purge_web_pages_dir(): ) +def _compile_page_from_app( + app: App, + route: str, + *, + save_page: bool = True, +) -> Component: + """Evaluate a single app page for compatibility call sites. + + Args: + app: The app containing the page definition. + route: The route to evaluate. + save_page: Whether to store the evaluated component on ``app._pages``. + + Returns: + The evaluated page component. + """ + from reflex.state import all_base_state_classes + + n_states_before = len(all_base_state_classes) + component = compile_unevaluated_page( + route, + app._unevaluated_pages[route], + app.style, + app.theme, + ) + + if len(all_base_state_classes) > n_states_before: + app._stateful_pages[route] = None + + app._check_routes_conflict(route) + if save_page: + app._pages[route] = component + return component + + +def _resolve_app_wrap_components( + app: App, + page_app_wrap_components: dict[tuple[int, str], Component], +) -> dict[tuple[int, str], Component]: + """Build the full app-wrap registry for compilation. + + Returns: + The merged app-wrap component mapping. + """ + config = get_config() + + app_wrappers: dict[tuple[int, str], Component] = { + (0, "AppWrap"): AppWrap.create(), + **page_app_wrap_components, + } + + if app.theme is not None: + app_wrappers[20, "Theme"] = app.theme + + if config.react_strict_mode: + app_wrappers[200, "StrictMode"] = StrictMode.create() + + if (toaster := app.toaster) is not None: + from reflex_core.components.component import memo + + @memo + def memoized_toast_provider(): + return toaster + + app_wrappers[44, "ToasterProvider"] = Fragment.create(memoized_toast_provider()) + + for wrap_mapping in (app.app_wraps, app.extra_app_wraps): + for key, app_wrap in wrap_mapping.items(): + component = app_wrap(app._state is not None) + if component is not None: + app_wrappers[key] = component + + return app_wrappers + + +def _build_app_root(app_wrappers: dict[tuple[int, str], Component]) -> Component: + """Create the wrapped app root component from ordered wrappers. + + Returns: + The wrapped app root component. + """ + for component in tuple(app_wrappers.values()): + app_wrappers.update(component._get_all_app_wrap_components()) + + order = sorted(app_wrappers, key=operator.itemgetter(0), reverse=True) + root = copy.deepcopy(app_wrappers[order[0]]) + parent = root + for key in order[1:]: + child = copy.deepcopy(app_wrappers[key]) + parent.children.append(child) + parent = child + return root + + +def compile_app( + app: App, + *, + prerender_routes: bool = False, + dry_run: bool = False, + use_rich: bool = True, +) -> None: + """Compile an app using the compiler plugin pipeline.""" + from reflex_core.utils.exceptions import ReflexRuntimeError + + app._apply_decorated_pages() + app._pages = {} + + should_compile = app._should_compile() + backend_dir = prerequisites.get_backend_dir() + if not dry_run and not should_compile and backend_dir.exists(): + stateful_pages_marker = backend_dir / constants.Dirs.STATEFUL_PAGES + if stateful_pages_marker.exists(): + with stateful_pages_marker.open("r") as file: + stateful_pages = json.load(file) + for route in stateful_pages: + console.debug(f"BE Evaluating stateful page: {route}") + _compile_page_from_app(app, route, save_page=False) + app._add_optional_endpoints() + return + + if constants.Page404.SLUG not in app._unevaluated_pages: + app.add_page(route=constants.Page404.SLUG) + + app.style = evaluate_style_namespaces(app.style) + config = get_config() + + if not should_compile and not dry_run: + with console.timing("Evaluate Pages (Backend)"): + for route in app._unevaluated_pages: + console.debug(f"Evaluating page: {route}") + _compile_page_from_app(app, route, save_page=False) + + app._write_stateful_pages_marker() + app._add_optional_endpoints() + return + + progress = ( + Progress( + *Progress.get_default_columns()[:-1], + MofNCompleteColumn(), + TimeElapsedColumn(), + ) + if use_rich + else console.PoorProgress() + ) + fixed_steps = 7 + base_total = (len(app._unevaluated_pages) * 2) + fixed_steps + len(config.plugins) + progress.start() + task = progress.add_task("Compiling:", total=base_total) + + compile_ctx = CompileContext( + app=app, + pages=list(app._unevaluated_pages.values()), + hooks=CompilerHooks( + plugins=default_page_plugins(style=app.style, theme=app.theme) + ), + ) + + async def _compile_with_context() -> None: + async with compile_ctx: + await compile_ctx.compile( + apply_overlay=True, + evaluate_progress=lambda: progress.advance(task), + render_progress=lambda: progress.advance(task), + ) + + with console.timing("Compile pages"): + asyncio.run(_compile_with_context()) + + for route, page_ctx in compile_ctx.compiled_pages.items(): + app._check_routes_conflict(route) + if not isinstance(page_ctx.root_component, Component): + msg = ( + f"Compiled page {route!r} root must be a Component before it can " + "be registered on the app." + ) + raise TypeError(msg) + app._pages[route] = page_ctx.root_component + + app._stateful_pages.update(compile_ctx.stateful_routes) + app._write_stateful_pages_marker() + app._add_optional_endpoints() + app._validate_var_dependencies() + + if config.show_built_with_reflex is None: + if ( + get_compile_context() == constants.CompileContext.DEPLOY + and prerequisites.get_user_tier() in ["pro", "team", "enterprise"] + ): + config.show_built_with_reflex = False + else: + config.show_built_with_reflex = True + + if is_prod_mode() and config.show_built_with_reflex: + app._setup_sticky_badge() + + progress.advance(task) + + compile_results = [ + (page_ctx.output_path, page_ctx.output_code) + for page_ctx in compile_ctx.compiled_pages.values() + if page_ctx.output_path is not None and page_ctx.output_code is not None + ] + all_imports = compile_ctx.all_imports + + ( + memo_components_output, + memo_components_result, + memo_components_imports, + ) = compile_memo_components( + dict.fromkeys(CUSTOM_COMPONENTS.values()), + tuple(EXPERIMENTAL_MEMOS.values()), + ) + compile_results.append((memo_components_output, memo_components_result)) + all_imports = utils.merge_imports(all_imports, memo_components_imports) + progress.advance(task) + + if ( + code_uses_state_contexts(compile_ctx.stateful_components_code) + and app._state is None + ): + msg = ( + "To access rx.State in frontend components, at least one " + "subclass of rx.State must be defined in the app." + ) + raise ReflexRuntimeError(msg) + if compile_ctx.stateful_components_path is not None: + compile_results.append(( + compile_ctx.stateful_components_path, + compile_ctx.stateful_components_code, + )) + progress.advance(task) + + app_wrappers = _resolve_app_wrap_components(app, compile_ctx.app_wrap_components) + app_root = _build_app_root(app_wrappers) + all_imports = utils.merge_imports(all_imports, app_root._get_all_imports()) + + compile_results.append( + compile_document_root( + app.head_components, + html_lang=app.html_lang, + html_custom_attrs=( + {"suppressHydrationWarning": True, **app.html_custom_attrs} + if app.html_custom_attrs + else {"suppressHydrationWarning": True} + ), + ) + ) + progress.advance(task) + + assets_src = Path.cwd() / constants.Dirs.APP_ASSETS + if assets_src.is_dir() and not dry_run: + with console.timing("Copy assets"): + path_ops.update_directory_tree( + src=assets_src, + dest=Path.cwd() / prerequisites.get_web_dir() / constants.Dirs.PUBLIC, + ) + + save_tasks: list[ + tuple[ + Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + tuple[Any, ...], + dict[str, Any], + ] + ] = [] + modify_files_tasks: list[tuple[str, str, Callable[[str], str]]] = [] + + def add_save_task( + task_fn: Callable[..., list[tuple[str, str]] | tuple[str, str] | None], + /, + *args: Any, + **kwargs: Any, + ) -> None: + save_tasks.append((task_fn, args, kwargs)) + + for plugin in config.plugins: + plugin.pre_compile( + add_save_task=add_save_task, + add_modify_task=lambda *args, plugin=plugin: modify_files_tasks.append(( + plugin.__class__.__module__ + plugin.__class__.__name__, + *args, + )), + unevaluated_pages=list(app._unevaluated_pages.values()), + ) + + if save_tasks: + _set_progress_total(progress, task, base_total + len(save_tasks)) + + progress.advance(task, advance=len(config.plugins)) + + compile_results.append(compile_root_stylesheet(app.stylesheets, app.reset_style)) + progress.advance(task) + + compile_results.append(compile_theme(app.style)) + progress.advance(task) + + for task_fn, args, kwargs in save_tasks: + result = task_fn(*args, **kwargs) + if result is None: + progress.advance(task) + continue + if isinstance(result, list): + compile_results.extend(result) + else: + compile_results.append(result) + progress.advance(task) + + compile_results.append( + compile_contexts( + app._state, + app.theme, + ) + ) + if app.theme is not None: + app.theme.appearance = None # pyright: ignore[reportAttributeAccessIssue] + progress.advance(task) + + compile_results.append(compile_app_root(app_root)) + progress.advance(task) + + progress.stop() + + if dry_run: + return + + with console.timing("Install Frontend Packages"): + app._get_frontend_packages(all_imports) + + frontend_skeleton.update_react_router_config( + prerender_routes=prerender_routes, + ) + + if is_prod_mode(): + purge_web_pages_dir() + else: + keep_files = [Path(output_path) for output_path, _ in compile_results] + for page_file in Path( + prerequisites.get_web_dir() / constants.Dirs.PAGES / constants.Dirs.ROUTES + ).rglob("*"): + if page_file.is_file() and page_file not in keep_files: + page_file.unlink() + + output_mapping: dict[Path, str] = {} + for output_path, code in compile_results: + path = utils.resolve_path_of_web_dir(output_path) + if path in output_mapping: + console.warn( + f"Path {path} has two different outputs. The first one will be used." + ) + else: + output_mapping[path] = code + + for plugin in config.plugins: + for static_file_path, content in plugin.get_static_assets(): + path = utils.resolve_path_of_web_dir(static_file_path) + if path in output_mapping: + console.warn( + f"Plugin {plugin.__class__.__name__} is trying to write to {path} but it already exists. The plugin file will be ignored." + ) + else: + output_mapping[path] = ( + content.decode("utf-8") if isinstance(content, bytes) else content + ) + + for plugin_name, file_path, modify_fn in modify_files_tasks: + path = utils.resolve_path_of_web_dir(file_path) + file_content = output_mapping.get(path) + if file_content is None: + if path.exists(): + file_content = path.read_text() + else: + msg = f"Plugin {plugin_name} is trying to modify {path} but it does not exist." + raise FileNotFoundError(msg) + output_mapping[path] = modify_fn(file_content) + + with console.timing("Write to Disk"): + for output_path, code in output_mapping.items(): + utils.write_file(output_path, code) + + if TYPE_CHECKING: - from reflex.app import ComponentCallable, UnevaluatedPage + from reflex.app import App, ComponentCallable, UnevaluatedPage def _into_component_once( @@ -869,84 +1278,3 @@ def compile_unevaluated_page( raise else: return component - - -class ExecutorSafeFunctions: - """Helper class to allow parallelisation of parts of the compilation process. - - This class (and its class attributes) are available at global scope. - - In a multiprocessing context (like when using a ProcessPoolExecutor), the content of this - global class is logically replicated to any FORKED process. - - How it works: - * Before the child process is forked, ensure that we stash any input data required by any future - function call in the child process. - * After the child process is forked, the child process will have a copy of the global class, which - includes the previously stashed input data. - * Any task submitted to the child process simply needs a way to communicate which input data the - requested function call requires. - - Why do we need this? Passing input data directly to child process often not possible because the input data is not picklable. - The mechanic described here removes the need to pickle the input data at all. - - Limitations: - * This can never support returning unpicklable OUTPUT data. - * Any object mutations done by the child process will not propagate back to the parent process (fork goes one way!). - - """ - - COMPONENTS: dict[str, BaseComponent] = {} - UNCOMPILED_PAGES: dict[str, UnevaluatedPage] = {} - - @classmethod - def compile_page(cls, route: str) -> tuple[str, str]: - """Compile a page. - - Args: - route: The route of the page to compile. - - Returns: - The path and code of the compiled page. - """ - return compile_page(route, cls.COMPONENTS[route]) - - @classmethod - def compile_unevaluated_page( - cls, - route: str, - style: ComponentStyle, - theme: Component | None, - ) -> tuple[str, Component, tuple[str, str]]: - """Compile an unevaluated page. - - Args: - route: The route of the page to compile. - style: The style of the page. - theme: The theme of the page. - - Returns: - The route, compiled component, and compiled page. - """ - component = compile_unevaluated_page( - route, cls.UNCOMPILED_PAGES[route], style, theme - ) - return route, component, compile_page(route, component) - - @classmethod - def compile_theme(cls, style: ComponentStyle | None) -> tuple[str, str]: - """Compile the theme. - - Args: - style: The style to compile. - - Returns: - The path and code of the compiled theme. - - Raises: - ValueError: If the style is not set. - """ - if style is None: - msg = "STYLE should be set" - raise ValueError(msg) - return compile_theme(style) diff --git a/reflex/compiler/plugins/__init__.py b/reflex/compiler/plugins/__init__.py new file mode 100644 index 00000000000..ac5db2c084f --- /dev/null +++ b/reflex/compiler/plugins/__init__.py @@ -0,0 +1,43 @@ +"""Compiler plugin foundations for single-pass page compilation.""" + +from reflex.compiler.plugins.base import ( + BaseContext, + CompileComponentYield, + CompileContext, + CompilerHooks, + CompilerPlugin, + ComponentAndChildren, + PageContext, + PageDefinition, +) +from reflex.compiler.plugins.builtin import ( + ApplyStylePlugin, + ConsolidateAppWrapPlugin, + ConsolidateCustomCodePlugin, + ConsolidateDynamicImportsPlugin, + ConsolidateHooksPlugin, + ConsolidateImportsPlugin, + ConsolidateRefsPlugin, + DefaultPagePlugin, + default_page_plugins, +) + +__all__ = [ + "ApplyStylePlugin", + "BaseContext", + "CompileComponentYield", + "CompileContext", + "CompilerHooks", + "CompilerPlugin", + "ComponentAndChildren", + "ConsolidateAppWrapPlugin", + "ConsolidateCustomCodePlugin", + "ConsolidateDynamicImportsPlugin", + "ConsolidateHooksPlugin", + "ConsolidateImportsPlugin", + "ConsolidateRefsPlugin", + "DefaultPagePlugin", + "PageContext", + "PageDefinition", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/base.py b/reflex/compiler/plugins/base.py new file mode 100644 index 00000000000..9fe6e55ff5e --- /dev/null +++ b/reflex/compiler/plugins/base.py @@ -0,0 +1,624 @@ +"""Core compiler plugin infrastructure: protocols, contexts, and dispatch.""" + +from __future__ import annotations + +import asyncio +import dataclasses +import inspect +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence +from contextvars import ContextVar, Token +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Protocol, + TypeAlias, + cast, + overload, +) + +from reflex_core.components.component import BaseComponent, Component, StatefulComponent +from reflex_core.utils.imports import ParsedImportDict, collapse_imports, merge_imports +from reflex_core.vars import VarData +from typing_extensions import Self + +if TYPE_CHECKING: + from reflex.app import App + + +class PageDefinition(Protocol): + """Protocol for page-like objects compiled by :class:`CompileContext`.""" + + @property + def route(self) -> str: + """Return the declared route for the page.""" + ... + + @property + def component(self) -> Any: + """Return the declared component or page callable.""" + ... + + +ComponentAndChildren: TypeAlias = tuple[BaseComponent, tuple[BaseComponent, ...]] +CompileComponentYield: TypeAlias = BaseComponent | ComponentAndChildren | None + + +class CompilerPlugin(Protocol): + """Protocol for compiler plugins that participate in page compilation.""" + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Evaluate a page-like object into a page context. + + Args: + page_fn: The page callable or component-like object to evaluate. + page: The declared page metadata associated with ``page_fn``. + **kwargs: Additional compilation context for advanced plugins. + + Returns: + ``None`` to indicate that the plugin does not handle the page. + """ + return None + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Finalize a page context after its component tree has been traversed.""" + return + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[CompileComponentYield, ComponentAndChildren]: + """Inspect or transform a component during recursive compilation. + + Yields: + Optional replacements before and after child traversal. + """ + if False: # pragma: no cover + yield None + return + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class CompilerHooks: + """Dispatch compiler hooks across an ordered plugin chain.""" + + plugins: tuple[CompilerPlugin, ...] = () + + @staticmethod + def _get_hook_impl( + plugin: CompilerPlugin, + hook_name: str, + ) -> Callable[..., Any] | None: + """Return the concrete hook implementation for a plugin, if any. + + Plugins that inherit the default hook bodies from + :class:`CompilerPlugin` are treated as not implementing the hook and + are skipped by dispatch. + """ + plugin_impl = inspect.getattr_static(type(plugin), hook_name, None) + if plugin_impl is None: + return None + + base_impl = inspect.getattr_static(CompilerPlugin, hook_name, None) + if plugin_impl is base_impl: + return None + + return cast(Callable[..., Any], getattr(plugin, hook_name, None)) + + @overload + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: Literal[False] = False, + **kwargs: Any, + ) -> list[Any]: ... + + @overload + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: Literal[True], + **kwargs: Any, + ) -> Any | None: ... + + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: bool = False, + **kwargs: Any, + ) -> list[Any] | Any | None: + """Dispatch a coroutine hook across all plugins in registration order. + + Args: + hook_name: The plugin hook attribute to invoke. + *args: Positional arguments forwarded to the hook. + stop_on_result: Whether to return immediately on the first non-None + result instead of collecting all results. + **kwargs: Keyword arguments forwarded to the hook. + + Returns: + When ``stop_on_result`` is false, a list of hook return values in + registration order. Otherwise, the first non-None result, or + ``None`` if every plugin returned ``None``. + """ + if stop_on_result: + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, hook_name) + if hook_impl is None: + continue + result = await cast(Awaitable[Any], hook_impl(*args, **kwargs)) + if result is not None: + return result + return None + results: list[Any] = [] + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, hook_name) + if hook_impl is None: + continue + results.append(await cast(Awaitable[Any], hook_impl(*args, **kwargs))) + return results + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Return the first page context produced by the plugin chain.""" + result = await self._dispatch( + "eval_page", + page_fn, + stop_on_result=True, + page=page, + **kwargs, + ) + return cast(PageContext | None, result) + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Run all ``compile_page`` hooks in plugin order.""" + await self._dispatch("compile_page", page_ctx, **kwargs) + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> BaseComponent: + """Walk a component tree once while dispatching component hooks. + + Plugins are entered in registration order before children are visited and + unwound in reverse order after the structural children have been + compiled. Component-valued props are traversed after structural children + for collection and side effects, but their transformed values are not + written back to the parent in this foundational slice. + + Args: + comp: The component to compile. + **kwargs: Additional context shared with plugins. + + Returns: + The compiled component. + """ + active_generators: list[ + AsyncGenerator[CompileComponentYield, ComponentAndChildren] + ] = [] + compiled_component = comp + structural_children: tuple[BaseComponent, ...] | None = None + + try: + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, "compile_component") + if hook_impl is None: + continue + generator = cast( + AsyncGenerator[CompileComponentYield, ComponentAndChildren], + hook_impl(compiled_component, **kwargs), + ) + active_generators.append(generator) + try: + replacement = await anext(generator) + except StopAsyncIteration: + replacement = None + compiled_component, structural_children = self._apply_replacement( + compiled_component, + structural_children, + replacement, + ) + + if isinstance(compiled_component, StatefulComponent): + if not compiled_component.rendered_as_shared: + compiled_component.component = cast( + Component, + await self.compile_component( + compiled_component.component, + **{ + **kwargs, + "stateful_component": compiled_component, + }, + ), + ) + compiled_children = tuple(compiled_component.children) + else: + if structural_children is None: + structural_children = tuple(compiled_component.children) + compiled_children = await self._compile_children( + structural_children, + **kwargs, + ) + + if isinstance(compiled_component, Component): + for prop_component in compiled_component._get_components_in_props(): + await self.compile_component( + prop_component, + **{ + **kwargs, + "in_prop_tree": True, + }, + ) + + for generator in reversed(active_generators): + try: + replacement = await generator.asend(( + compiled_component, + compiled_children, + )) + except StopAsyncIteration: + replacement = None + compiled_component, replacement_children = self._apply_replacement( + compiled_component, + compiled_children, + replacement, + ) + if replacement_children is not compiled_children: + compiled_children = await self._compile_children( + replacement_children, + **kwargs, + ) + + compiled_component.children = list(compiled_children) + return compiled_component + finally: + for generator in reversed(active_generators): + await generator.aclose() + + async def _compile_children( + self, + children: Sequence[BaseComponent], + **kwargs: Any, + ) -> tuple[BaseComponent, ...]: + """Compile a sequence of structural children in order. + + Args: + children: The structural children to compile. + **kwargs: Additional keyword arguments forwarded to the walker. + + Returns: + The compiled children in their original order. + """ + compiled_children = [ + await self.compile_component(child, **kwargs) for child in children + ] + return tuple(compiled_children) + + @staticmethod + def _apply_replacement( + comp: BaseComponent, + children: tuple[BaseComponent, ...] | None, + replacement: CompileComponentYield, + ) -> tuple[BaseComponent, tuple[BaseComponent, ...] | None]: + """Apply a plugin replacement to the current component state. + + Args: + comp: The current component. + children: The current structural children. + replacement: The replacement returned by a plugin hook. + + Returns: + The updated component and structural children tuple. + """ + if replacement is None: + return comp, children + if isinstance(replacement, tuple): + return replacement + return replacement, children + + +@dataclasses.dataclass(kw_only=True) +class BaseContext: + """Async context manager that exposes itself through a class-local context var.""" + + __context_var__: ClassVar[ContextVar[Self | None]] + + _attached_context_token: Token[Self | None] | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a dedicated context variable for each subclass.""" + super().__init_subclass__(**kwargs) + cls.__context_var__ = ContextVar(cls.__name__, default=None) + + @classmethod + def get(cls) -> Self: + """Return the active context instance for the current task.""" + context = cls.__context_var__.get() + if context is None: + msg = f"No active {cls.__name__} is attached to the current context." + raise RuntimeError(msg) + return context + + async def __aenter__(self) -> Self: + """Attach this context to the current task. + + Returns: + The attached context instance. + """ + if self._attached_context_token is not None: + msg = "Context is already attached and cannot be entered twice." + raise RuntimeError(msg) + self._attached_context_token = type(self).__context_var__.set(self) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + if self._attached_context_token is None: + return + try: + type(self).__context_var__.reset(self._attached_context_token) + finally: + self._attached_context_token = None + + def ensure_context_attached(self) -> None: + """Ensure this instance is the active context for the current task.""" + try: + current = type(self).get() + except RuntimeError as err: + msg = ( + f"{type(self).__name__} must be entered with 'async with' before " + "calling this method." + ) + raise RuntimeError(msg) from err + if current is not self: + msg = f"{type(self).__name__} is not attached to the current task context." + raise RuntimeError(msg) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class PageContext(BaseContext): + """Mutable compilation state for a single page.""" + + name: str + route: str + root_component: BaseComponent + imports: list[ParsedImportDict] = dataclasses.field(default_factory=list) + module_code: dict[str, None] = dataclasses.field(default_factory=dict) + hooks: dict[str, VarData | None] = dataclasses.field(default_factory=dict) + dynamic_imports: set[str] = dataclasses.field(default_factory=set) + refs: dict[str, None] = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + frontend_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + output_path: str | None = None + output_code: str | None = None + + def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict: + """Return the imports accumulated for this page. + + Args: + collapse: Whether to deduplicate import vars before returning. + + Returns: + The merged import dictionary. + """ + imports = merge_imports(*self.imports) if self.imports else {} + return collapse_imports(imports) if collapse else imports + + def custom_code_dict(self) -> dict[str, None]: + """Return custom-code snippets keyed like legacy collectors.""" + return dict(self.module_code) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class CompileContext(BaseContext): + """Mutable compilation state for an entire compile run.""" + + app: App | None = None + pages: Sequence[PageDefinition] + hooks: CompilerHooks = dataclasses.field(default_factory=CompilerHooks) + compiled_pages: dict[str, PageContext] = dataclasses.field(default_factory=dict) + all_imports: ParsedImportDict = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + stateful_routes: dict[str, None] = dataclasses.field(default_factory=dict) + stateful_components_path: str | None = None + stateful_components_code: str = "" + + async def compile( + self, + *, + evaluate_progress: Callable[[], None] | None = None, + render_progress: Callable[[], None] | None = None, + apply_overlay: bool = False, + **kwargs: Any, + ) -> dict[str, PageContext]: + """Compile all configured pages through the plugin pipeline. + + Args: + evaluate_progress: Optional callback invoked after each page is evaluated. + render_progress: Optional callback invoked after each page is rendered. + apply_overlay: Whether to apply the app's overlay component to pages. + **kwargs: Additional keyword arguments forwarded to plugin hooks. + + Returns: + The compiled pages keyed by route. + + Raises: + RuntimeError: If no plugin can evaluate a page, or if two compiled + pages resolve to the same route. + """ + from reflex.compiler import compiler + from reflex.state import all_base_state_classes + from reflex.utils.exec import is_prod_mode + + self.ensure_context_attached() + self.compiled_pages.clear() + self.all_imports.clear() + self.app_wrap_components.clear() + self.stateful_routes.clear() + self.stateful_components_path = None + self.stateful_components_code = "" + + overlay_component: Component | None = None + if ( + apply_overlay + and self.app is not None + and self.app.overlay_component is not None + ): + overlay_component = self.app._generate_component(self.app.overlay_component) + + for page in self.pages: + page_fn = page.component + n_states_before = len(all_base_state_classes) + page_ctx = await self.hooks.eval_page( + page_fn, + page=page, + compile_context=self, + **kwargs, + ) + if page_ctx is None: + page_name = getattr(page_fn, "__name__", repr(page_fn)) + msg = ( + f"No compiler plugin was able to evaluate page {page.route!r} " + f"({page_name})." + ) + raise RuntimeError(msg) + if page_ctx.route in self.compiled_pages: + msg = f"Duplicate compiled page route {page_ctx.route!r}." + raise RuntimeError(msg) + + if len(all_base_state_classes) > n_states_before: + self.stateful_routes[page.route] = None + + if overlay_component is not None and self.app is not None: + if not isinstance(page_ctx.root_component, Component): + msg = ( + f"Compiled page {page_ctx.route!r} root must be a Component " + "to apply the overlay." + ) + raise TypeError(msg) + page_ctx.root_component = self.app._add_overlay_to_component( + page_ctx.root_component, + overlay_component, + ) + + async with page_ctx: + page_ctx.root_component = await self.hooks.compile_component( + page_ctx.root_component, + page=page, + page_context=page_ctx, + compile_context=self, + **kwargs, + ) + await self.hooks.compile_page( + page_ctx, + page=page, + compile_context=self, + **kwargs, + ) + + page_ctx.frontend_imports = page_ctx.merged_imports(collapse=True) + self.all_imports = merge_imports( + self.all_imports, page_ctx.frontend_imports + ) + self.app_wrap_components.update(page_ctx.app_wrap_components) + self.compiled_pages[page_ctx.route] = page_ctx + if evaluate_progress is not None: + evaluate_progress() + + page_components: list[BaseComponent] = [] + for page_ctx in self.compiled_pages.values(): + page_component = StatefulComponent.compile_from(page_ctx.root_component) + if page_component is None: + page_component = page_ctx.root_component + page_ctx.root_component = page_component + page_components.append(page_component) + + self.stateful_components_path = compiler.utils.get_stateful_components_path() + self.stateful_components_code = ( + compiler._compile_stateful_components(page_components) + if is_prod_mode() + else "" + ) + + for page_ctx in self.compiled_pages.values(): + imports = collapse_imports(page_ctx.root_component._get_all_imports()) + page_ctx.imports = [imports] if imports else [] + page_ctx.dynamic_imports = ( + page_ctx.root_component._get_all_dynamic_imports() + ) + page_ctx.module_code = page_ctx.root_component._get_all_custom_code() + page_ctx.hooks = page_ctx.root_component._get_all_hooks() + page_ctx.refs = page_ctx.root_component._get_all_refs() + + async def _render_page(ctx: PageContext) -> None: + ctx.output_path, ctx.output_code = await asyncio.to_thread( + compiler.compile_page, + ctx.route, + ctx.root_component, + ) + if render_progress is not None: + render_progress() + + await asyncio.gather( + *(_render_page(ctx) for ctx in self.compiled_pages.values()) + ) + + return self.compiled_pages + + +__all__ = [ + "BaseContext", + "CompileComponentYield", + "CompileContext", + "CompilerHooks", + "CompilerPlugin", + "ComponentAndChildren", + "PageContext", + "PageDefinition", +] diff --git a/reflex/compiler/plugins/builtin.py b/reflex/compiler/plugins/builtin.py new file mode 100644 index 00000000000..e90c593d697 --- /dev/null +++ b/reflex/compiler/plugins/builtin.py @@ -0,0 +1,409 @@ +"""Built-in compiler plugins and the default plugin pipeline.""" + +from __future__ import annotations + +import dataclasses +from collections.abc import AsyncGenerator +from typing import Any + +from reflex_components_core.base.fragment import Fragment +from reflex_core.components.component import ( + BaseComponent, + Component, + ComponentStyle, + StatefulComponent, +) +from reflex_core.config import get_config +from reflex_core.utils.format import make_default_page_title +from reflex_core.utils.imports import collapse_imports, merge_imports +from reflex_core.vars import VarData + +from reflex.compiler import utils +from reflex.compiler.plugins.base import ( + CompilerPlugin, + ComponentAndChildren, + PageContext, + PageDefinition, +) + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultPagePlugin(CompilerPlugin): + """Evaluate an unevaluated page into a mutable page context.""" + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + """Evaluate the page function and attach legacy page metadata. + + Returns: + The initialized page context for the evaluated page. + """ + from reflex.compiler import compiler + + try: + component = compiler.into_component(page_fn) + component = Fragment.create(component) + + meta_args = { + "title": getattr(page, "title", None) + or make_default_page_title(get_config().app_name, page.route), + "image": getattr(page, "image", ""), + "meta": getattr(page, "meta", ()), + } + if (description := getattr(page, "description", None)) is not None: + meta_args["description"] = description + + utils.add_meta(component, **meta_args) + except Exception as err: + if hasattr(err, "add_note"): + err.add_note(f"Happened while evaluating page {page.route!r}") + raise + + return PageContext( + name=getattr(page_fn, "__name__", page.route), + route=page.route, + root_component=component, + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ApplyStylePlugin(CompilerPlugin): + """Apply app-level styles in the descending phase of the walk.""" + + style: ComponentStyle | None = None + theme: Component | None = None + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Apply the non-recursive portion of ``_add_style_recursive``.""" + del kwargs, stateful_component + if self.style is not None and isinstance(comp, Component) and not in_prop_tree: + if type(comp)._add_style != Component._add_style: + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) + + new_style = comp._add_style() + style_vars = [new_style._var_data] + + component_style = comp._get_component_style(self.style) + if component_style: + new_style.update(component_style) + style_vars.append(component_style._var_data) + + new_style.update(comp.style) + style_vars.append(comp.style._var_data) + new_style._var_data = VarData.merge(*style_vars) + comp.style = new_style + + yield + + +class ConsolidateImportsPlugin(CompilerPlugin): + """Collect per-component imports and merge them after traversal.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect non-recursive imports for structural components.""" + del kwargs + if isinstance(comp, StatefulComponent): + if comp.rendered_as_shared: + PageContext.get().imports.append(comp._get_all_imports()) + yield + return + + if not in_prop_tree and isinstance(comp, Component): + imports = comp._get_imports() + if imports: + PageContext.get().imports.append(imports) + + yield + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Collapse collected imports into a single legacy-shaped entry.""" + del kwargs + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + +class ConsolidateHooksPlugin(CompilerPlugin): + """Collect component hooks while skipping prop and stateful subtrees.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the single-component hook contributions.""" + del kwargs + compiled_component, _ = yield + if ( + in_prop_tree + or stateful_component is not None + or isinstance(compiled_component, StatefulComponent) + or not isinstance(compiled_component, Component) + ): + return + + hooks = {} + hooks.update(compiled_component._get_hooks_internal()) + if (user_hooks := compiled_component._get_hooks()) is not None: + hooks[user_hooks] = None + hooks.update(compiled_component._get_added_hooks()) + PageContext.get().hooks.update(hooks) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ConsolidateCustomCodePlugin(CompilerPlugin): + """Collect custom module code while preserving legacy ordering.""" + + stateful_custom_code_export: bool = False + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect custom code in legacy order for the active subtree.""" + del kwargs + page_ctx = PageContext.get() + + if isinstance(comp, StatefulComponent): + yield + if not comp.rendered_as_shared: + page_ctx.module_code[ + comp._render_stateful_code(export=self.stateful_custom_code_export) + ] = None + return + + if in_prop_tree or not isinstance(comp, Component): + yield + return + + if (custom_code := comp._get_custom_code()) is not None: + page_ctx.module_code[custom_code] = None + + for prop_component in comp._get_components_in_props(): + page_ctx.module_code.update(self._collect_prop_custom_code(prop_component)) + + for clz in comp._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(comp): + page_ctx.module_code[item] = None + + yield + + def _collect_prop_custom_code( + self, + component: BaseComponent, + ) -> dict[str, None]: + """Recursively collect custom code for a prop-component subtree. + + Returns: + The collected custom-code snippets keyed in legacy order. + """ + if isinstance(component, StatefulComponent): + if component.rendered_as_shared: + return {} + + code = self._collect_prop_custom_code(component.component) + code[ + component._render_stateful_code(export=self.stateful_custom_code_export) + ] = None + return code + + if not isinstance(component, Component): + return component._get_all_custom_code() + + code: dict[str, None] = {} + if (custom_code := component._get_custom_code()) is not None: + code[custom_code] = None + + for prop_component in component._get_components_in_props(): + code.update(self._collect_prop_custom_code(prop_component)) + + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + code[item] = None + + for child in component.children: + code.update(self._collect_prop_custom_code(child)) + + return code + + +class ConsolidateDynamicImportsPlugin(CompilerPlugin): + """Collect dynamic imports from the active component tree.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the current component's dynamic import.""" + del kwargs + compiled_component, _ = yield + if isinstance(compiled_component, StatefulComponent) or not isinstance( + compiled_component, Component + ): + return + if dynamic_import := compiled_component._get_dynamic_imports(): + PageContext.get().dynamic_imports.add(dynamic_import) + + +class ConsolidateRefsPlugin(CompilerPlugin): + """Collect refs from the active component tree.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the current component ref when present.""" + del kwargs + compiled_component, _ = yield + if isinstance(compiled_component, StatefulComponent) or not isinstance( + compiled_component, Component + ): + return + if (ref := compiled_component.get_ref()) is not None: + PageContext.get().refs[ref] = None + + +class ConsolidateAppWrapPlugin(CompilerPlugin): + """Collect app-wrap components using the page walk plus wrapper recursion.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect direct wrappers and recursively expand wrapper subtrees.""" + del kwargs + compiled_component, _ = yield + if ( + in_prop_tree + or stateful_component is not None + or not isinstance(compiled_component, Component) + ): + return + + page_ctx = PageContext.get() + direct_wrappers = compiled_component._get_app_wrap_components() + if not direct_wrappers: + return + + ignore_ids = {id(wrapper) for wrapper in page_ctx.app_wrap_components.values()} + page_ctx.app_wrap_components.update(direct_wrappers) + for wrapper in direct_wrappers.values(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + page_ctx.app_wrap_components.update( + self._collect_wrapper_subtree(wrapper, ignore_ids) + ) + + def _collect_wrapper_subtree( + self, + component: Component, + ignore_ids: set[int], + ) -> dict[tuple[int, str], Component]: + """Collect app-wrap components reachable through a wrapper subtree. + + Returns: + The nested wrapper mapping discovered from the wrapper subtree. + """ + components: dict[tuple[int, str], Component] = {} + + direct_wrappers = component._get_app_wrap_components() + for key, wrapper in direct_wrappers.items(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + components[key] = wrapper + components.update(self._collect_wrapper_subtree(wrapper, ignore_ids)) + + for child in component.children: + if not isinstance(child, Component): + continue + child_id = id(child) + if child_id in ignore_ids: + continue + ignore_ids.add(child_id) + components.update(self._collect_wrapper_subtree(child, ignore_ids)) + + return components + + +def default_page_plugins( + *, + style: ComponentStyle | None = None, + theme: Component | None = None, + stateful_custom_code_export: bool = False, +) -> tuple[CompilerPlugin, ...]: + """Return the default compiler plugin ordering for page compilation.""" + return ( + DefaultPagePlugin(), + ApplyStylePlugin(style=style, theme=theme), + ConsolidateCustomCodePlugin( + stateful_custom_code_export=stateful_custom_code_export + ), + ConsolidateDynamicImportsPlugin(), + ConsolidateRefsPlugin(), + ConsolidateHooksPlugin(), + ConsolidateAppWrapPlugin(), + ConsolidateImportsPlugin(), + ) + + +__all__ = [ + "ApplyStylePlugin", + "ConsolidateAppWrapPlugin", + "ConsolidateCustomCodePlugin", + "ConsolidateDynamicImportsPlugin", + "ConsolidateHooksPlugin", + "ConsolidateImportsPlugin", + "ConsolidateRefsPlugin", + "DefaultPagePlugin", + "default_page_plugins", +] diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py new file mode 100644 index 00000000000..91f4c4ffa9d --- /dev/null +++ b/tests/units/compiler/test_plugins.py @@ -0,0 +1,773 @@ +# ruff: noqa: D101, D102 + +import asyncio +import dataclasses +from collections.abc import AsyncGenerator, Callable +from typing import Any, cast + +import pytest +from reflex_components_core.base.fragment import Fragment +from reflex_core.components.component import ( + BaseComponent, + Component, + ComponentStyle, + field, +) +from reflex_core.utils import format as format_utils +from reflex_core.utils.imports import ImportVar, collapse_imports, merge_imports + +from reflex.compiler.plugins import ( + ApplyStylePlugin, + BaseContext, + CompileContext, + CompilerHooks, + CompilerPlugin, + ComponentAndChildren, + ConsolidateAppWrapPlugin, + ConsolidateCustomCodePlugin, + ConsolidateDynamicImportsPlugin, + ConsolidateHooksPlugin, + ConsolidateImportsPlugin, + ConsolidateRefsPlugin, + DefaultPagePlugin, + PageContext, + default_page_plugins, +) + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: str | None = None + description: str | None = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +class StubCompilerPlugin(CompilerPlugin): + pass + + +class WrapperComponent(Component): + tag = "WrapperComponent" + library = "wrapper-lib" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(20, "NestedWrap"): Fragment.create()} + + +class RootComponent(Component): + tag = "RootComponent" + library = "root-lib" + + slot: Component | None = field(default=None) + + def add_style(self) -> dict[str, Any] | None: + return {"display": "flex"} + + def add_custom_code(self) -> list[str]: + return ["const rootAddedCode = 1;"] + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(10, "Wrap"): WrapperComponent.create()} + + +class ChildComponent(Component): + tag = "ChildComponent" + library = "child-lib" + + def add_style(self) -> dict[str, Any] | None: + return {"align_items": "center"} + + def add_custom_code(self) -> list[str]: + return ["const childAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const childCustomCode = 1;" + + def _get_hooks(self) -> str | None: + return "const childHook = useChildHook();" + + +class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" + + def add_custom_code(self) -> list[str]: + return ["const propAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const propCustomCode = 1;" + + def _get_dynamic_imports(self) -> str | None: + return "dynamic(() => import('prop-lib'))" + + def _get_hooks(self) -> str | None: + return "const propHook = usePropHook();" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(15, "PropWrap"): Fragment.create()} + + +async def collect_page_context( + component: BaseComponent, + *, + plugins: tuple[Any, ...], +) -> PageContext: + page_ctx = PageContext( + name="page", + route="/page", + root_component=component, + ) + hooks = CompilerHooks(plugins=plugins) + + async with page_ctx: + page_ctx.root_component = await hooks.compile_component(page_ctx.root_component) + await hooks.compile_page(page_ctx) + + return page_ctx + + +def create_component_tree() -> RootComponent: + return RootComponent.create( + ChildComponent.create(id="child-id", style={"color": "red"}), + slot=PropComponent.create(id="prop-id", style={"opacity": "0.5"}), + style={"margin": "0"}, + ) + + +def page_style() -> ComponentStyle: + return { + RootComponent: {"padding": "1rem"}, + ChildComponent: {"font_size": "12px"}, + PropComponent: {"border": "1px solid green"}, + } + + +class EvalPagePlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + component = page_fn() if callable(page_fn) else page_fn + if not isinstance(component, BaseComponent): + msg = f"Expected a BaseComponent, got {type(component).__name__}." + raise TypeError(msg) + name = getattr(page_fn, "__name__", page.route) + return PageContext( + name=name, + route=page.route, + root_component=component, + ) + + +class CollectPageDataPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + compiled_component, _children = yield + if isinstance(compiled_component, Component): + page_ctx = PageContext.get() + imports = compiled_component._get_imports() + if imports: + page_ctx.imports.append(imports) + page_ctx.hooks.update(compiled_component._get_hooks_internal()) + if hooks := compiled_component._get_hooks(): + page_ctx.hooks[hooks] = None + page_ctx.hooks.update(compiled_component._get_added_hooks()) + if module_code := compiled_component._get_custom_code(): + page_ctx.module_code[module_code] = None + if dynamic_import := compiled_component._get_dynamic_imports(): + page_ctx.dynamic_imports.add(dynamic_import) + if ref := compiled_component.get_ref(): + page_ctx.refs[ref] = None + page_ctx.app_wrap_components.update( + compiled_component._get_app_wrap_components() + ) + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + +@pytest.mark.asyncio +async def test_eval_page_uses_first_non_none_result() -> None: + calls: list[str] = [] + page = FakePage(route="/demo", component=lambda: Fragment.create()) + + class NoMatchPlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> None: + del page_fn, page, kwargs + calls.append("no-match") + return + + class MatchPlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + calls.append("match") + return PageContext( + name="page", + route=page.route, + root_component=page_fn(), + ) + + class UnreachablePlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + del page_fn, page, kwargs + calls.append("unreachable") + msg = "eval_page should stop at the first page context" + raise AssertionError(msg) + + hooks = CompilerHooks(plugins=(NoMatchPlugin(), MatchPlugin(), UnreachablePlugin())) + + page_ctx = await hooks.eval_page(page.component, page=page, compile_context=None) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert calls == ["no-match", "match"] + + +@pytest.mark.asyncio +async def test_compile_page_runs_plugins_in_registration_order() -> None: + calls: list[str] = [] + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + + class FirstPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("first") + + class SecondPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("second") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + + await hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["first", "second"] + + +@pytest.mark.asyncio +async def test_compile_page_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + calls: list[str] = [] + + async def fail_compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del self, page_ctx, kwargs + await asyncio.sleep(0) + msg = "Inherited protocol compile_page hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(CompilerPlugin, "compile_page", fail_compile_page) + + class ProtocolOnlyPlugin(CompilerPlugin): + pass + + class RealPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("real") + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + + await hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["real"] + + +@pytest.mark.asyncio +async def test_compile_component_orders_pre_and_post_by_plugin() -> None: + events: list[str] = [] + root = RootComponent.create() + + class FirstPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("first:pre") + yield + events.append("first:post") + + class SecondPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("second:pre") + yield + events.append("second:post") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + + compiled_root = await hooks.compile_component(root) + + assert compiled_root is root + assert events == ["first:pre", "second:pre", "second:post", "first:post"] + + +@pytest.mark.asyncio +async def test_compile_component_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + events: list[str] = [] + root = RootComponent.create() + + async def fail_compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + del self, comp, kwargs + await asyncio.sleep(0) + msg = "Inherited protocol compile_component hook should be skipped." + raise AssertionError(msg) + if False: # pragma: no cover + yield None + + monkeypatch.setattr( + CompilerPlugin, + "compile_component", + fail_compile_component, + ) + + class ProtocolOnlyPlugin(CompilerPlugin): + pass + + class RealPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("real:pre") + yield + events.append("real:post") + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + + compiled_root = await hooks.compile_component(root) + + assert compiled_root is root + assert events == ["real:pre", "real:post"] + + +@pytest.mark.asyncio +async def test_compile_component_traverses_children_before_prop_components() -> None: + visited: list[str] = [] + root = RootComponent.create( + ChildComponent.create(), + slot=PropComponent.create(), + ) + + class VisitPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + if isinstance(comp, Component): + visited.append(comp.tag or type(comp).__name__) + yield + + hooks = CompilerHooks(plugins=(VisitPlugin(),)) + await hooks.compile_component(root) + + assert visited == ["RootComponent", "ChildComponent", "PropComponent"] + + +@pytest.mark.asyncio +async def test_context_lifecycle_and_cleanup() -> None: + compile_ctx = CompileContext(pages=[], hooks=CompilerHooks()) + page_ctx = PageContext( + name="page", + route="/ctx", + root_component=Fragment.create(), + ) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + with pytest.raises(RuntimeError, match="must be entered with 'async with'"): + compile_ctx.ensure_context_attached() + + async with compile_ctx: + assert CompileContext.get() is compile_ctx + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + async with page_ctx: + assert CompileContext.get() is compile_ctx + assert PageContext.get() is page_ctx + page_ctx.ensure_context_attached() + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + assert CompileContext.get() is compile_ctx + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + with pytest.raises(ValueError, match="boom"): + async with compile_ctx: + msg = "boom" + raise ValueError(msg) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + +def test_page_context_default_factories_are_isolated() -> None: + page_ctx_a = PageContext( + name="a", + route="/a", + root_component=Fragment.create(), + ) + page_ctx_b = PageContext( + name="b", + route="/b", + root_component=Fragment.create(), + ) + + page_ctx_a.imports.append({"lib-a": [ImportVar(tag="ThingA")]}) + page_ctx_a.module_code["const a = 1;"] = None + page_ctx_a.hooks["hookA"] = None + page_ctx_a.dynamic_imports.add("dynamic-a") + page_ctx_a.refs["refA"] = None + page_ctx_a.app_wrap_components[1, "WrapA"] = Fragment.create() + + assert page_ctx_b.imports == [] + assert page_ctx_b.module_code == {} + assert page_ctx_b.hooks == {} + assert page_ctx_b.dynamic_imports == set() + assert page_ctx_b.refs == {} + assert page_ctx_b.app_wrap_components == {} + + +def test_page_context_helpers_preserve_accumulated_values() -> None: + page_ctx = PageContext( + name="page", + route="/page", + root_component=Fragment.create(), + ) + page_ctx.imports.extend([ + {"lib-a": [ImportVar(tag="ThingA")]}, + {"lib-a": [ImportVar(tag="ThingB")], "lib-b": [ImportVar(tag="ThingC")]}, + ]) + page_ctx.module_code["const first = 1;"] = None + page_ctx.module_code["const second = 2;"] = None + + assert page_ctx.merged_imports() == merge_imports(*page_ctx.imports) + assert page_ctx.merged_imports(collapse=True) == collapse_imports( + merge_imports(*page_ctx.imports) + ) + assert list(page_ctx.custom_code_dict()) == [ + "const first = 1;", + "const second = 2;", + ] + + +def test_base_context_subclasses_initialize_distinct_context_vars() -> None: + class DynamicContext(BaseContext): + pass + + class AnotherDynamicContext(BaseContext): + pass + + assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__ + + +@pytest.mark.asyncio +async def test_apply_style_plugin_matches_legacy_recursive_behavior() -> None: + legacy_component = create_component_tree() + plugin_component = create_component_tree() + style = page_style() + + legacy_component._add_style_recursive(style) + page_ctx = await collect_page_context( + plugin_component, + plugins=(ApplyStylePlugin(style=style),), + ) + compiled_root = cast(RootComponent, page_ctx.root_component) + assert compiled_root.slot is not None + assert legacy_component.slot is not None + + assert compiled_root.render() == legacy_component.render() + assert compiled_root.slot.render() == legacy_component.slot.render() + + +@pytest.mark.asyncio +async def test_consolidate_imports_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateImportsPlugin(),), + ) + + assert page_ctx.merged_imports(collapse=True) == root._get_all_imports( + collapse=True + ) + + +@pytest.mark.asyncio +async def test_consolidate_hooks_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateHooksPlugin(),), + ) + + assert page_ctx.hooks == root._get_all_hooks() + assert "const propHook = usePropHook();" not in page_ctx.hooks + + +@pytest.mark.asyncio +async def test_consolidate_custom_code_plugin_matches_legacy_recursive_collector() -> ( + None +): + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateCustomCodePlugin(),), + ) + + assert page_ctx.custom_code_dict() == root._get_all_custom_code() + assert list(page_ctx.custom_code_dict()) == list(root._get_all_custom_code()) + + +@pytest.mark.asyncio +async def test_consolidate_dynamic_imports_plugin_matches_legacy_recursive_collector() -> ( + None +): + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateDynamicImportsPlugin(),), + ) + + assert page_ctx.dynamic_imports == root._get_all_dynamic_imports() + + +@pytest.mark.asyncio +async def test_consolidate_refs_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateRefsPlugin(),), + ) + + assert page_ctx.refs == root._get_all_refs() + + +@pytest.mark.asyncio +async def test_consolidate_app_wrap_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateAppWrapPlugin(),), + ) + + assert ( + page_ctx.app_wrap_components.keys() + == root._get_all_app_wrap_components().keys() + ) + assert (15, "PropWrap") not in page_ctx.app_wrap_components + assert (20, "NestedWrap") in page_ctx.app_wrap_components + + +@pytest.mark.asyncio +async def test_default_page_plugin_evaluates_page_like_legacy_compile_path() -> None: + page = FakePage( + route="/demo", + component=create_component_tree, + title="Demo", + description="Demo page", + image="demo.png", + meta=({"name": "robots", "content": "index,follow"},), + ) + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + + page_ctx = await hooks.eval_page(page.component, page=page) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert page_ctx.name == "create_component_tree" + assert any(child.tag == "title" for child in page_ctx.root_component.children) + assert any(child.tag == "meta" for child in page_ctx.root_component.children) + + +@pytest.mark.asyncio +async def test_default_plugin_pipeline_matches_legacy_page_context_data() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=default_page_plugins(), + ) + + assert page_ctx.merged_imports(collapse=True) == root._get_all_imports( + collapse=True + ) + assert page_ctx.hooks == root._get_all_hooks() + assert page_ctx.custom_code_dict() == root._get_all_custom_code() + assert page_ctx.dynamic_imports == root._get_all_dynamic_imports() + assert page_ctx.refs == root._get_all_refs() + assert ( + page_ctx.app_wrap_components.keys() + == root._get_all_app_wrap_components().keys() + ) + + +@pytest.mark.asyncio +async def test_compile_context_compiles_pages_and_accumulates_page_data() -> None: + page = FakePage( + route="/demo", + component=lambda: RootComponent.create( + ChildComponent.create(id="child-id"), + slot=PropComponent.create(), + ), + ) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks( + plugins=(EvalPagePlugin(), CollectPageDataPlugin()), + ), + ) + + async with compile_ctx: + compiled_pages = await compile_ctx.compile() + + assert compiled_pages is compile_ctx.compiled_pages + assert list(compiled_pages) == ["/demo"] + + page_ctx = compiled_pages["/demo"] + assert page_ctx.name == "" + assert page_ctx.route == "/demo" + assert page_ctx.imports + assert set(page_ctx.frontend_imports) >= { + "root-lib", + "child-lib", + "prop-lib", + "react", + } + assert page_ctx.output_path is not None + assert page_ctx.output_code is not None + assert "RootComponent" in page_ctx.output_code + assert page_ctx.module_code == { + "const propCustomCode = 1;": None, + "const propAddedCode = 1;": None, + "const rootAddedCode = 1;": None, + "const childCustomCode = 1;": None, + "const childAddedCode = 1;": None, + } + assert page_ctx.dynamic_imports == {"dynamic(() => import('prop-lib'))"} + assert any("useChildHook" in hook for hook in page_ctx.hooks) + assert any("useRef" in hook for hook in page_ctx.hooks) + assert page_ctx.refs == {format_utils.format_ref("child-id"): None} + assert (10, "Wrap") in page_ctx.app_wrap_components + assert (15, "PropWrap") in page_ctx.app_wrap_components + + +@pytest.mark.asyncio +async def test_compile_context_rejects_duplicate_routes() -> None: + pages = [ + FakePage(route="/duplicate", component=lambda: Fragment.create()), + FakePage(route="/duplicate", component=lambda: Fragment.create()), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=(EvalPagePlugin(),)), + ) + + async with compile_ctx: + with pytest.raises(RuntimeError, match="Duplicate compiled page route"): + await compile_ctx.compile() + + +@pytest.mark.asyncio +async def test_compile_context_requires_attached_context() -> None: + compile_ctx = CompileContext( + pages=[], + hooks=CompilerHooks(), + ) + + with pytest.raises(RuntimeError, match="must be entered with 'async with'"): + await compile_ctx.compile() diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index d5fee452c5a..5f5ba11019a 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -306,9 +306,16 @@ async def test_oplock_contention_queue( state_manager_2._oplock_enabled = True modify_started = asyncio.Event() - modify_2_started = asyncio.Event() + contenders_started = asyncio.Event() modify_1_continue = asyncio.Event() modify_2_continue = asyncio.Event() + contender_started_count = 0 + + def mark_contender_started() -> None: + nonlocal contender_started_count + contender_started_count += 1 + if contender_started_count == 2: + contenders_started.set() async def modify_1(): async with state_manager_redis.modify_state( @@ -321,7 +328,7 @@ async def modify_1(): async def modify_2(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -331,7 +338,7 @@ async def modify_2(): async def modify_3(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -343,7 +350,7 @@ async def modify_3(): task_2 = asyncio.create_task(modify_2()) task_3 = asyncio.create_task(modify_3()) - await modify_2_started.wait() + await contenders_started.wait() # Let modify 1 complete modify_1_continue.set() @@ -407,9 +414,16 @@ async def test_oplock_contention_no_lease( state_manager_3._oplock_enabled = True modify_started = asyncio.Event() - modify_2_started = asyncio.Event() + contenders_started = asyncio.Event() modify_1_continue = asyncio.Event() modify_2_continue = asyncio.Event() + contender_started_count = 0 + + def mark_contender_started() -> None: + nonlocal contender_started_count + contender_started_count += 1 + if contender_started_count == 2: + contenders_started.set() async def modify_1(): async with state_manager_redis.modify_state( @@ -422,7 +436,7 @@ async def modify_1(): async def modify_2(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -432,7 +446,7 @@ async def modify_2(): async def modify_3(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_3.modify_state( _substate_key(token, root_state), ) as new_state: @@ -444,7 +458,7 @@ async def modify_3(): task_2 = asyncio.create_task(modify_2()) task_3 = asyncio.create_task(modify_3()) - await modify_2_started.wait() + await contenders_started.wait() # Let modify 1 complete modify_1_continue.set() diff --git a/tests/units/test_app.py b/tests/units/test_app.py index d34cb93283d..c1dc12e606b 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -23,6 +23,7 @@ from reflex_core.components.component import Component from reflex_core.constants.state import FIELD_MARKER from reflex_core.event import Event +from reflex_core.plugins import Plugin from reflex_core.style import Style from reflex_core.utils import console, exceptions, format from reflex_core.vars.base import computed_var @@ -2033,6 +2034,65 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: yield app, web_dir +def test_compile_executes_plugin_save_and_modify_tasks_sequentially( + compilable_app: tuple[App, Path], + mocker, +): + """Test plugin pre-compile tasks run sequentially and modify outputs.""" + events: list[str] = [] + + class OrderedPlugin(Plugin): + def pre_compile(self, **context): + def save_first(): + events.append("first") + return "plugin/ordered.txt", "alpha" + + def save_second(): + events.append("second") + return "plugin/second.txt", "beta" + + context["add_save_task"](save_first) + context["add_save_task"](save_second) + context["add_modify_task"]( + "plugin/ordered.txt", + lambda content: content + "-omega", + ) + + conf = rx.Config(app_name="testing", plugins=[OrderedPlugin()]) + mocker.patch("reflex_core.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + app.add_page(rx.box("Index"), route="index") + + app._compile() + + assert events == ["first", "second"] + assert (web_dir / "plugin" / "ordered.txt").read_text() == "alpha-omega" + assert (web_dir / "plugin" / "second.txt").read_text() == "beta" + + +def test_compile_keeps_first_duplicate_save_task_output( + compilable_app: tuple[App, Path], + mocker, +): + """Test duplicate save-task outputs keep the first generated file.""" + + class DuplicateOutputPlugin(Plugin): + def pre_compile(self, **context): + context["add_save_task"](lambda: ("plugin/duplicate.txt", "first")) + context["add_save_task"](lambda: ("plugin/duplicate.txt", "second")) + + conf = rx.Config(app_name="testing", plugins=[DuplicateOutputPlugin()]) + mocker.patch("reflex_core.config._get_config", return_value=conf) + app, web_dir = compilable_app + mocker.patch("reflex.utils.prerequisites.get_web_dir", return_value=web_dir) + app.add_page(rx.box("Index"), route="index") + + app._compile() + + assert (web_dir / "plugin" / "duplicate.txt").read_text() == "first" + + @pytest.mark.parametrize( "react_strict_mode", [True, False], diff --git a/tests/units/test_environment.py b/tests/units/test_environment.py index ab1b805a4d1..65369999df0 100644 --- a/tests/units/test_environment.py +++ b/tests/units/test_environment.py @@ -12,7 +12,6 @@ from reflex_core.environment import ( EnvironmentVariables, EnvVar, - ExecutorType, ExistingPath, PerformanceMode, SequenceOptions, @@ -408,47 +407,6 @@ class TestEnv: assert env_var_instance.default == "default" -class TestExecutorType: - """Test the ExecutorType enum and related functionality.""" - - def test_executor_type_values(self): - """Test ExecutorType enum values.""" - assert ExecutorType.THREAD.value == "thread" - assert ExecutorType.PROCESS.value == "process" - assert ExecutorType.MAIN_THREAD.value == "main_thread" - - def test_get_executor_main_thread_mode(self): - """Test executor selection in main thread mode.""" - with ( - patch.object( - environment.REFLEX_COMPILE_EXECUTOR, - "get", - return_value=ExecutorType.MAIN_THREAD, - ), - patch.object( - environment.REFLEX_COMPILE_PROCESSES, "get", return_value=None - ), - patch.object(environment.REFLEX_COMPILE_THREADS, "get", return_value=None), - ): - executor = ExecutorType.get_executor_from_environment() - - # Test the main thread executor functionality - with executor: - future = executor.submit(lambda x: x * 2, 5) - assert future.result() == 10 - - def test_get_executor_returns_executor(self): - """Test that get_executor_from_environment returns an executor.""" - # Test with default values - should return some kind of executor - executor = ExecutorType.get_executor_from_environment() - assert executor is not None - - # Test that we can use it as a context manager - with executor: - future = executor.submit(lambda: "test") - assert future.result() == "test" - - class TestUtilityFunctions: """Test utility functions."""