diff --git a/pyproject.toml b/pyproject.toml index 4bfb76a3608..2ac4db66942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,7 +230,7 @@ asyncio_mode = "auto" [tool.codespell] skip = "docs/*,*.html,examples/*, *.pyi, poetry.lock, uv.lock" -ignore-words-list = "te, TreeE, selectin" +ignore-words-list = "te, TreeE, selectin, asend" [tool.coverage.run] diff --git a/reflex/compiler/plugins/__init__.py b/reflex/compiler/plugins/__init__.py new file mode 100644 index 00000000000..ac5db2c084f --- /dev/null +++ b/reflex/compiler/plugins/__init__.py @@ -0,0 +1,43 @@ +"""Compiler plugin foundations for single-pass page compilation.""" + +from reflex.compiler.plugins.base import ( + BaseContext, + CompileComponentYield, + CompileContext, + CompilerHooks, + CompilerPlugin, + ComponentAndChildren, + PageContext, + PageDefinition, +) +from reflex.compiler.plugins.builtin import ( + ApplyStylePlugin, + ConsolidateAppWrapPlugin, + ConsolidateCustomCodePlugin, + ConsolidateDynamicImportsPlugin, + ConsolidateHooksPlugin, + ConsolidateImportsPlugin, + ConsolidateRefsPlugin, + DefaultPagePlugin, + default_page_plugins, +) + +__all__ = [ + "ApplyStylePlugin", + "BaseContext", + "CompileComponentYield", + "CompileContext", + "CompilerHooks", + "CompilerPlugin", + "ComponentAndChildren", + "ConsolidateAppWrapPlugin", + "ConsolidateCustomCodePlugin", + "ConsolidateDynamicImportsPlugin", + "ConsolidateHooksPlugin", + "ConsolidateImportsPlugin", + "ConsolidateRefsPlugin", + "DefaultPagePlugin", + "PageContext", + "PageDefinition", + "default_page_plugins", +] diff --git a/reflex/compiler/plugins/base.py b/reflex/compiler/plugins/base.py new file mode 100644 index 00000000000..d180f3131da --- /dev/null +++ b/reflex/compiler/plugins/base.py @@ -0,0 +1,505 @@ +"""Core compiler plugin infrastructure: protocols, contexts, and dispatch.""" + +from __future__ import annotations + +import dataclasses +import inspect +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence +from contextvars import ContextVar, Token +from types import TracebackType +from typing import Any, ClassVar, Literal, Protocol, TypeAlias, cast, overload + +from reflex_core.components.component import BaseComponent, Component, StatefulComponent +from reflex_core.utils.imports import ParsedImportDict, collapse_imports, merge_imports +from reflex_core.vars import VarData +from typing_extensions import Self + + +class PageDefinition(Protocol): + """Protocol for page-like objects compiled by :class:`CompileContext`.""" + + route: str + component: Any + + +ComponentAndChildren: TypeAlias = tuple[BaseComponent, tuple[BaseComponent, ...]] +CompileComponentYield: TypeAlias = BaseComponent | ComponentAndChildren | None + + +class CompilerPlugin(Protocol): + """Protocol for compiler plugins that participate in page compilation.""" + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Evaluate a page-like object into a page context. + + Args: + page_fn: The page callable or component-like object to evaluate. + page: The declared page metadata associated with ``page_fn``. + **kwargs: Additional compilation context for advanced plugins. + + Returns: + ``None`` to indicate that the plugin does not handle the page. + """ + return None + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Finalize a page context after its component tree has been traversed.""" + return + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[CompileComponentYield, ComponentAndChildren]: + """Inspect or transform a component during recursive compilation. + + Yields: + Optional replacements before and after child traversal. + """ + if False: # pragma: no cover + yield None + return + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class CompilerHooks: + """Dispatch compiler hooks across an ordered plugin chain.""" + + plugins: tuple[CompilerPlugin, ...] = () + + @staticmethod + def _get_hook_impl( + plugin: CompilerPlugin, + hook_name: str, + ) -> Callable[..., Any] | None: + """Return the concrete hook implementation for a plugin, if any. + + Plugins that inherit the default hook bodies from + :class:`CompilerPlugin` are treated as not implementing the hook and + are skipped by dispatch. + """ + plugin_impl = inspect.getattr_static(type(plugin), hook_name, None) + if plugin_impl is None: + return None + + base_impl = inspect.getattr_static(CompilerPlugin, hook_name, None) + if plugin_impl is base_impl: + return None + + return cast(Callable[..., Any], getattr(plugin, hook_name, None)) + + @overload + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: Literal[False] = False, + **kwargs: Any, + ) -> list[Any]: ... + + @overload + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: Literal[True], + **kwargs: Any, + ) -> Any | None: ... + + async def _dispatch( + self, + hook_name: str, + *args: Any, + stop_on_result: bool = False, + **kwargs: Any, + ) -> list[Any] | Any | None: + """Dispatch a coroutine hook across all plugins in registration order. + + Args: + hook_name: The plugin hook attribute to invoke. + *args: Positional arguments forwarded to the hook. + stop_on_result: Whether to return immediately on the first non-None + result instead of collecting all results. + **kwargs: Keyword arguments forwarded to the hook. + + Returns: + When ``stop_on_result`` is false, a list of hook return values in + registration order. Otherwise, the first non-None result, or + ``None`` if every plugin returned ``None``. + """ + if stop_on_result: + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, hook_name) + if hook_impl is None: + continue + result = await cast(Awaitable[Any], hook_impl(*args, **kwargs)) + if result is not None: + return result + return None + results: list[Any] = [] + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, hook_name) + if hook_impl is None: + continue + results.append(await cast(Awaitable[Any], hook_impl(*args, **kwargs))) + return results + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext | None: + """Return the first page context produced by the plugin chain.""" + result = await self._dispatch( + "eval_page", + page_fn, + stop_on_result=True, + page=page, + **kwargs, + ) + return cast(PageContext | None, result) + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Run all ``compile_page`` hooks in plugin order.""" + await self._dispatch("compile_page", page_ctx, **kwargs) + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> BaseComponent: + """Walk a component tree once while dispatching component hooks. + + Plugins are entered in registration order before children are visited and + unwound in reverse order after the structural children have been + compiled. Component-valued props are traversed after structural children + for collection and side effects, but their transformed values are not + written back to the parent in this foundational slice. + + Args: + comp: The component to compile. + **kwargs: Additional context shared with plugins. + + Returns: + The compiled component. + """ + active_generators: list[ + AsyncGenerator[CompileComponentYield, ComponentAndChildren] + ] = [] + compiled_component = comp + structural_children: tuple[BaseComponent, ...] | None = None + + try: + for plugin in self.plugins: + hook_impl = self._get_hook_impl(plugin, "compile_component") + if hook_impl is None: + continue + generator = cast( + AsyncGenerator[CompileComponentYield, ComponentAndChildren], + hook_impl(compiled_component, **kwargs), + ) + active_generators.append(generator) + try: + replacement = await anext(generator) + except StopAsyncIteration: + replacement = None + compiled_component, structural_children = self._apply_replacement( + compiled_component, + structural_children, + replacement, + ) + + if isinstance(compiled_component, StatefulComponent): + if not compiled_component.rendered_as_shared: + compiled_component.component = cast( + Component, + await self.compile_component( + compiled_component.component, + **{ + **kwargs, + "stateful_component": compiled_component, + }, + ), + ) + compiled_children = tuple(compiled_component.children) + else: + if structural_children is None: + structural_children = tuple(compiled_component.children) + compiled_children = await self._compile_children( + structural_children, + **kwargs, + ) + + if isinstance(compiled_component, Component): + for prop_component in compiled_component._get_components_in_props(): + await self.compile_component( + prop_component, + **{ + **kwargs, + "in_prop_tree": True, + }, + ) + + for generator in reversed(active_generators): + try: + replacement = await generator.asend(( + compiled_component, + compiled_children, + )) + except StopAsyncIteration: + replacement = None + compiled_component, replacement_children = self._apply_replacement( + compiled_component, + compiled_children, + replacement, + ) + if replacement_children is not compiled_children: + compiled_children = await self._compile_children( + replacement_children, + **kwargs, + ) + + compiled_component.children = list(compiled_children) + return compiled_component + finally: + for generator in reversed(active_generators): + await generator.aclose() + + async def _compile_children( + self, + children: Sequence[BaseComponent], + **kwargs: Any, + ) -> tuple[BaseComponent, ...]: + """Compile a sequence of structural children in order. + + Args: + children: The structural children to compile. + **kwargs: Additional keyword arguments forwarded to the walker. + + Returns: + The compiled children in their original order. + """ + compiled_children = [ + await self.compile_component(child, **kwargs) for child in children + ] + return tuple(compiled_children) + + @staticmethod + def _apply_replacement( + comp: BaseComponent, + children: tuple[BaseComponent, ...] | None, + replacement: CompileComponentYield, + ) -> tuple[BaseComponent, tuple[BaseComponent, ...] | None]: + """Apply a plugin replacement to the current component state. + + Args: + comp: The current component. + children: The current structural children. + replacement: The replacement returned by a plugin hook. + + Returns: + The updated component and structural children tuple. + """ + if replacement is None: + return comp, children + if isinstance(replacement, tuple): + return replacement + return replacement, children + + +@dataclasses.dataclass(kw_only=True) +class BaseContext: + """Async context manager that exposes itself through a class-local context var.""" + + __context_var__: ClassVar[ContextVar[Self | None]] + + _attached_context_token: Token[Self | None] | None = dataclasses.field( + default=None, + init=False, + repr=False, + ) + + @classmethod + def __init_subclass__(cls, **kwargs: Any) -> None: + """Initialize a dedicated context variable for each subclass.""" + super().__init_subclass__(**kwargs) + cls.__context_var__ = ContextVar(cls.__name__, default=None) + + @classmethod + def get(cls) -> Self: + """Return the active context instance for the current task.""" + context = cls.__context_var__.get() + if context is None: + msg = f"No active {cls.__name__} is attached to the current context." + raise RuntimeError(msg) + return context + + async def __aenter__(self) -> Self: + """Attach this context to the current task. + + Returns: + The attached context instance. + """ + if self._attached_context_token is not None: + msg = "Context is already attached and cannot be entered twice." + raise RuntimeError(msg) + self._attached_context_token = type(self).__context_var__.set(self) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Detach this context from the current task.""" + if self._attached_context_token is None: + return + try: + type(self).__context_var__.reset(self._attached_context_token) + finally: + self._attached_context_token = None + + def ensure_context_attached(self) -> None: + """Ensure this instance is the active context for the current task.""" + try: + current = type(self).get() + except RuntimeError as err: + msg = ( + f"{type(self).__name__} must be entered with 'async with' before " + "calling this method." + ) + raise RuntimeError(msg) from err + if current is not self: + msg = f"{type(self).__name__} is not attached to the current task context." + raise RuntimeError(msg) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class PageContext(BaseContext): + """Mutable compilation state for a single page.""" + + name: str + route: str + root_component: BaseComponent + imports: list[ParsedImportDict] = dataclasses.field(default_factory=list) + module_code: dict[str, None] = dataclasses.field(default_factory=dict) + hooks: dict[str, VarData | None] = dataclasses.field(default_factory=dict) + dynamic_imports: set[str] = dataclasses.field(default_factory=set) + refs: dict[str, None] = dataclasses.field(default_factory=dict) + app_wrap_components: dict[tuple[int, str], Component] = dataclasses.field( + default_factory=dict + ) + + def merged_imports(self, *, collapse: bool = False) -> ParsedImportDict: + """Return the imports accumulated for this page. + + Args: + collapse: Whether to deduplicate import vars before returning. + + Returns: + The merged import dictionary. + """ + imports = merge_imports(*self.imports) if self.imports else {} + return collapse_imports(imports) if collapse else imports + + def custom_code_dict(self) -> dict[str, None]: + """Return custom-code snippets keyed like legacy collectors.""" + return dict(self.module_code) + + +@dataclasses.dataclass(slots=True, kw_only=True) +class CompileContext(BaseContext): + """Mutable compilation state for an entire compile run.""" + + pages: Sequence[PageDefinition] + hooks: CompilerHooks = dataclasses.field(default_factory=CompilerHooks) + compiled_pages: dict[str, PageContext] = dataclasses.field(default_factory=dict) + + async def compile(self, **kwargs: Any) -> dict[str, PageContext]: + """Compile all configured pages through the plugin pipeline. + + Args: + **kwargs: Additional keyword arguments forwarded to plugin hooks. + + Returns: + The compiled pages keyed by route. + + Raises: + RuntimeError: If no plugin can evaluate a page, or if two compiled + pages resolve to the same route. + """ + self.ensure_context_attached() + self.compiled_pages.clear() + + for page in self.pages: + page_fn = page.component + page_ctx = await self.hooks.eval_page( + page_fn, + page=page, + compile_context=self, + **kwargs, + ) + if page_ctx is None: + page_name = getattr(page_fn, "__name__", repr(page_fn)) + msg = ( + f"No compiler plugin was able to evaluate page {page.route!r} " + f"({page_name})." + ) + raise RuntimeError(msg) + if page_ctx.route in self.compiled_pages: + msg = f"Duplicate compiled page route {page_ctx.route!r}." + raise RuntimeError(msg) + + async with page_ctx: + page_ctx.root_component = await self.hooks.compile_component( + page_ctx.root_component, + page=page, + page_context=page_ctx, + compile_context=self, + **kwargs, + ) + await self.hooks.compile_page( + page_ctx, + page=page, + compile_context=self, + **kwargs, + ) + + self.compiled_pages[page_ctx.route] = page_ctx + + return self.compiled_pages + + +__all__ = [ + "BaseContext", + "CompileComponentYield", + "CompileContext", + "CompilerHooks", + "CompilerPlugin", + "ComponentAndChildren", + "PageContext", + "PageDefinition", +] diff --git a/reflex/compiler/plugins/builtin.py b/reflex/compiler/plugins/builtin.py new file mode 100644 index 00000000000..e90c593d697 --- /dev/null +++ b/reflex/compiler/plugins/builtin.py @@ -0,0 +1,409 @@ +"""Built-in compiler plugins and the default plugin pipeline.""" + +from __future__ import annotations + +import dataclasses +from collections.abc import AsyncGenerator +from typing import Any + +from reflex_components_core.base.fragment import Fragment +from reflex_core.components.component import ( + BaseComponent, + Component, + ComponentStyle, + StatefulComponent, +) +from reflex_core.config import get_config +from reflex_core.utils.format import make_default_page_title +from reflex_core.utils.imports import collapse_imports, merge_imports +from reflex_core.vars import VarData + +from reflex.compiler import utils +from reflex.compiler.plugins.base import ( + CompilerPlugin, + ComponentAndChildren, + PageContext, + PageDefinition, +) + + +@dataclasses.dataclass(frozen=True, slots=True) +class DefaultPagePlugin(CompilerPlugin): + """Evaluate an unevaluated page into a mutable page context.""" + + async def eval_page( + self, + page_fn: Any, + /, + *, + page: PageDefinition, + **kwargs: Any, + ) -> PageContext: + """Evaluate the page function and attach legacy page metadata. + + Returns: + The initialized page context for the evaluated page. + """ + from reflex.compiler import compiler + + try: + component = compiler.into_component(page_fn) + component = Fragment.create(component) + + meta_args = { + "title": getattr(page, "title", None) + or make_default_page_title(get_config().app_name, page.route), + "image": getattr(page, "image", ""), + "meta": getattr(page, "meta", ()), + } + if (description := getattr(page, "description", None)) is not None: + meta_args["description"] = description + + utils.add_meta(component, **meta_args) + except Exception as err: + if hasattr(err, "add_note"): + err.add_note(f"Happened while evaluating page {page.route!r}") + raise + + return PageContext( + name=getattr(page_fn, "__name__", page.route), + route=page.route, + root_component=component, + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ApplyStylePlugin(CompilerPlugin): + """Apply app-level styles in the descending phase of the walk.""" + + style: ComponentStyle | None = None + theme: Component | None = None + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Apply the non-recursive portion of ``_add_style_recursive``.""" + del kwargs, stateful_component + if self.style is not None and isinstance(comp, Component) and not in_prop_tree: + if type(comp)._add_style != Component._add_style: + msg = "Do not override _add_style directly. Use add_style instead." + raise UserWarning(msg) + + new_style = comp._add_style() + style_vars = [new_style._var_data] + + component_style = comp._get_component_style(self.style) + if component_style: + new_style.update(component_style) + style_vars.append(component_style._var_data) + + new_style.update(comp.style) + style_vars.append(comp.style._var_data) + new_style._var_data = VarData.merge(*style_vars) + comp.style = new_style + + yield + + +class ConsolidateImportsPlugin(CompilerPlugin): + """Collect per-component imports and merge them after traversal.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect non-recursive imports for structural components.""" + del kwargs + if isinstance(comp, StatefulComponent): + if comp.rendered_as_shared: + PageContext.get().imports.append(comp._get_all_imports()) + yield + return + + if not in_prop_tree and isinstance(comp, Component): + imports = comp._get_imports() + if imports: + PageContext.get().imports.append(imports) + + yield + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + """Collapse collected imports into a single legacy-shaped entry.""" + del kwargs + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + +class ConsolidateHooksPlugin(CompilerPlugin): + """Collect component hooks while skipping prop and stateful subtrees.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the single-component hook contributions.""" + del kwargs + compiled_component, _ = yield + if ( + in_prop_tree + or stateful_component is not None + or isinstance(compiled_component, StatefulComponent) + or not isinstance(compiled_component, Component) + ): + return + + hooks = {} + hooks.update(compiled_component._get_hooks_internal()) + if (user_hooks := compiled_component._get_hooks()) is not None: + hooks[user_hooks] = None + hooks.update(compiled_component._get_added_hooks()) + PageContext.get().hooks.update(hooks) + + +@dataclasses.dataclass(frozen=True, slots=True) +class ConsolidateCustomCodePlugin(CompilerPlugin): + """Collect custom module code while preserving legacy ordering.""" + + stateful_custom_code_export: bool = False + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect custom code in legacy order for the active subtree.""" + del kwargs + page_ctx = PageContext.get() + + if isinstance(comp, StatefulComponent): + yield + if not comp.rendered_as_shared: + page_ctx.module_code[ + comp._render_stateful_code(export=self.stateful_custom_code_export) + ] = None + return + + if in_prop_tree or not isinstance(comp, Component): + yield + return + + if (custom_code := comp._get_custom_code()) is not None: + page_ctx.module_code[custom_code] = None + + for prop_component in comp._get_components_in_props(): + page_ctx.module_code.update(self._collect_prop_custom_code(prop_component)) + + for clz in comp._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(comp): + page_ctx.module_code[item] = None + + yield + + def _collect_prop_custom_code( + self, + component: BaseComponent, + ) -> dict[str, None]: + """Recursively collect custom code for a prop-component subtree. + + Returns: + The collected custom-code snippets keyed in legacy order. + """ + if isinstance(component, StatefulComponent): + if component.rendered_as_shared: + return {} + + code = self._collect_prop_custom_code(component.component) + code[ + component._render_stateful_code(export=self.stateful_custom_code_export) + ] = None + return code + + if not isinstance(component, Component): + return component._get_all_custom_code() + + code: dict[str, None] = {} + if (custom_code := component._get_custom_code()) is not None: + code[custom_code] = None + + for prop_component in component._get_components_in_props(): + code.update(self._collect_prop_custom_code(prop_component)) + + for clz in component._iter_parent_classes_with_method("add_custom_code"): + for item in clz.add_custom_code(component): + code[item] = None + + for child in component.children: + code.update(self._collect_prop_custom_code(child)) + + return code + + +class ConsolidateDynamicImportsPlugin(CompilerPlugin): + """Collect dynamic imports from the active component tree.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the current component's dynamic import.""" + del kwargs + compiled_component, _ = yield + if isinstance(compiled_component, StatefulComponent) or not isinstance( + compiled_component, Component + ): + return + if dynamic_import := compiled_component._get_dynamic_imports(): + PageContext.get().dynamic_imports.add(dynamic_import) + + +class ConsolidateRefsPlugin(CompilerPlugin): + """Collect refs from the active component tree.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect the current component ref when present.""" + del kwargs + compiled_component, _ = yield + if isinstance(compiled_component, StatefulComponent) or not isinstance( + compiled_component, Component + ): + return + if (ref := compiled_component.get_ref()) is not None: + PageContext.get().refs[ref] = None + + +class ConsolidateAppWrapPlugin(CompilerPlugin): + """Collect app-wrap components using the page walk plus wrapper recursion.""" + + async def compile_component( + self, + comp: BaseComponent, + /, + *, + in_prop_tree: bool = False, + stateful_component: StatefulComponent | None = None, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + """Collect direct wrappers and recursively expand wrapper subtrees.""" + del kwargs + compiled_component, _ = yield + if ( + in_prop_tree + or stateful_component is not None + or not isinstance(compiled_component, Component) + ): + return + + page_ctx = PageContext.get() + direct_wrappers = compiled_component._get_app_wrap_components() + if not direct_wrappers: + return + + ignore_ids = {id(wrapper) for wrapper in page_ctx.app_wrap_components.values()} + page_ctx.app_wrap_components.update(direct_wrappers) + for wrapper in direct_wrappers.values(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + page_ctx.app_wrap_components.update( + self._collect_wrapper_subtree(wrapper, ignore_ids) + ) + + def _collect_wrapper_subtree( + self, + component: Component, + ignore_ids: set[int], + ) -> dict[tuple[int, str], Component]: + """Collect app-wrap components reachable through a wrapper subtree. + + Returns: + The nested wrapper mapping discovered from the wrapper subtree. + """ + components: dict[tuple[int, str], Component] = {} + + direct_wrappers = component._get_app_wrap_components() + for key, wrapper in direct_wrappers.items(): + wrapper_id = id(wrapper) + if wrapper_id in ignore_ids: + continue + ignore_ids.add(wrapper_id) + components[key] = wrapper + components.update(self._collect_wrapper_subtree(wrapper, ignore_ids)) + + for child in component.children: + if not isinstance(child, Component): + continue + child_id = id(child) + if child_id in ignore_ids: + continue + ignore_ids.add(child_id) + components.update(self._collect_wrapper_subtree(child, ignore_ids)) + + return components + + +def default_page_plugins( + *, + style: ComponentStyle | None = None, + theme: Component | None = None, + stateful_custom_code_export: bool = False, +) -> tuple[CompilerPlugin, ...]: + """Return the default compiler plugin ordering for page compilation.""" + return ( + DefaultPagePlugin(), + ApplyStylePlugin(style=style, theme=theme), + ConsolidateCustomCodePlugin( + stateful_custom_code_export=stateful_custom_code_export + ), + ConsolidateDynamicImportsPlugin(), + ConsolidateRefsPlugin(), + ConsolidateHooksPlugin(), + ConsolidateAppWrapPlugin(), + ConsolidateImportsPlugin(), + ) + + +__all__ = [ + "ApplyStylePlugin", + "ConsolidateAppWrapPlugin", + "ConsolidateCustomCodePlugin", + "ConsolidateDynamicImportsPlugin", + "ConsolidateHooksPlugin", + "ConsolidateImportsPlugin", + "ConsolidateRefsPlugin", + "DefaultPagePlugin", + "default_page_plugins", +] diff --git a/tests/units/compiler/test_plugins.py b/tests/units/compiler/test_plugins.py new file mode 100644 index 00000000000..22e2b456ddf --- /dev/null +++ b/tests/units/compiler/test_plugins.py @@ -0,0 +1,762 @@ +# ruff: noqa: D101, D102 + +import asyncio +import dataclasses +from collections.abc import AsyncGenerator, Callable +from typing import Any, cast + +import pytest +from reflex_components_core.base.fragment import Fragment +from reflex_core.components.component import ( + BaseComponent, + Component, + ComponentStyle, + field, +) +from reflex_core.utils import format as format_utils +from reflex_core.utils.imports import ImportVar, collapse_imports, merge_imports + +from reflex.compiler.plugins import ( + ApplyStylePlugin, + BaseContext, + CompileContext, + CompilerHooks, + CompilerPlugin, + ComponentAndChildren, + ConsolidateAppWrapPlugin, + ConsolidateCustomCodePlugin, + ConsolidateDynamicImportsPlugin, + ConsolidateHooksPlugin, + ConsolidateImportsPlugin, + ConsolidateRefsPlugin, + DefaultPagePlugin, + PageContext, + default_page_plugins, +) + + +@dataclasses.dataclass(slots=True) +class FakePage: + route: str + component: Callable[[], Component] + title: str | None = None + description: str | None = None + image: str = "" + meta: tuple[dict[str, Any], ...] = () + + +class StubCompilerPlugin(CompilerPlugin): + pass + + +class WrapperComponent(Component): + tag = "WrapperComponent" + library = "wrapper-lib" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(20, "NestedWrap"): Fragment.create()} + + +class RootComponent(Component): + tag = "RootComponent" + library = "root-lib" + + slot: Component | None = field(default=None) + + def add_style(self) -> dict[str, Any] | None: + return {"display": "flex"} + + def add_custom_code(self) -> list[str]: + return ["const rootAddedCode = 1;"] + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(10, "Wrap"): WrapperComponent.create()} + + +class ChildComponent(Component): + tag = "ChildComponent" + library = "child-lib" + + def add_style(self) -> dict[str, Any] | None: + return {"align_items": "center"} + + def add_custom_code(self) -> list[str]: + return ["const childAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const childCustomCode = 1;" + + def _get_hooks(self) -> str | None: + return "const childHook = useChildHook();" + + +class PropComponent(Component): + tag = "PropComponent" + library = "prop-lib" + + def add_custom_code(self) -> list[str]: + return ["const propAddedCode = 1;"] + + def _get_custom_code(self) -> str | None: + return "const propCustomCode = 1;" + + def _get_dynamic_imports(self) -> str | None: + return "dynamic(() => import('prop-lib'))" + + def _get_hooks(self) -> str | None: + return "const propHook = usePropHook();" + + @staticmethod + def _get_app_wrap_components() -> dict[tuple[int, str], Component]: + return {(15, "PropWrap"): Fragment.create()} + + +async def collect_page_context( + component: BaseComponent, + *, + plugins: tuple[Any, ...], +) -> PageContext: + page_ctx = PageContext( + name="page", + route="/page", + root_component=component, + ) + hooks = CompilerHooks(plugins=plugins) + + async with page_ctx: + page_ctx.root_component = await hooks.compile_component(page_ctx.root_component) + await hooks.compile_page(page_ctx) + + return page_ctx + + +def create_component_tree() -> RootComponent: + return RootComponent.create( + ChildComponent.create(id="child-id", style={"color": "red"}), + slot=PropComponent.create(id="prop-id", style={"opacity": "0.5"}), + style={"margin": "0"}, + ) + + +def page_style() -> ComponentStyle: + return { + RootComponent: {"padding": "1rem"}, + ChildComponent: {"font_size": "12px"}, + PropComponent: {"border": "1px solid green"}, + } + + +class EvalPagePlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + component = page_fn() if callable(page_fn) else page_fn + if not isinstance(component, BaseComponent): + msg = f"Expected a BaseComponent, got {type(component).__name__}." + raise TypeError(msg) + name = getattr(page_fn, "__name__", page.route) + return PageContext( + name=name, + route=page.route, + root_component=component, + ) + + +class CollectPageDataPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + compiled_component, _children = yield + if isinstance(compiled_component, Component): + page_ctx = PageContext.get() + imports = compiled_component._get_imports() + if imports: + page_ctx.imports.append(imports) + page_ctx.hooks.update(compiled_component._get_hooks_internal()) + if hooks := compiled_component._get_hooks(): + page_ctx.hooks[hooks] = None + page_ctx.hooks.update(compiled_component._get_added_hooks()) + if module_code := compiled_component._get_custom_code(): + page_ctx.module_code[module_code] = None + if dynamic_import := compiled_component._get_dynamic_imports(): + page_ctx.dynamic_imports.add(dynamic_import) + if ref := compiled_component.get_ref(): + page_ctx.refs[ref] = None + page_ctx.app_wrap_components.update( + compiled_component._get_app_wrap_components() + ) + + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + page_ctx.imports = ( + [collapse_imports(merge_imports(*page_ctx.imports))] + if page_ctx.imports + else [] + ) + + +@pytest.mark.asyncio +async def test_eval_page_uses_first_non_none_result() -> None: + calls: list[str] = [] + page = FakePage(route="/demo", component=lambda: Fragment.create()) + + class NoMatchPlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> None: + del page_fn, page, kwargs + calls.append("no-match") + return + + class MatchPlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + calls.append("match") + return PageContext( + name="page", + route=page.route, + root_component=page_fn(), + ) + + class UnreachablePlugin(StubCompilerPlugin): + async def eval_page( + self, + page_fn: Any, + /, + *, + page: FakePage, + **kwargs: Any, + ) -> PageContext: + del page_fn, page, kwargs + calls.append("unreachable") + msg = "eval_page should stop at the first page context" + raise AssertionError(msg) + + hooks = CompilerHooks(plugins=(NoMatchPlugin(), MatchPlugin(), UnreachablePlugin())) + + page_ctx = await hooks.eval_page(page.component, page=page, compile_context=None) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert calls == ["no-match", "match"] + + +@pytest.mark.asyncio +async def test_compile_page_runs_plugins_in_registration_order() -> None: + calls: list[str] = [] + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + + class FirstPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("first") + + class SecondPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("second") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + + await hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["first", "second"] + + +@pytest.mark.asyncio +async def test_compile_page_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + page_ctx = PageContext( + name="page", + route="/ordered", + root_component=Fragment.create(), + ) + calls: list[str] = [] + + async def fail_compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + del self, page_ctx, kwargs + await asyncio.sleep(0) + msg = "Inherited protocol compile_page hook should be skipped." + raise AssertionError(msg) + + monkeypatch.setattr(CompilerPlugin, "compile_page", fail_compile_page) + + class ProtocolOnlyPlugin(CompilerPlugin): + pass + + class RealPlugin(StubCompilerPlugin): + async def compile_page( + self, + page_ctx: PageContext, + /, + **kwargs: Any, + ) -> None: + calls.append("real") + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + + await hooks.compile_page(page_ctx, compile_context=None) + + assert calls == ["real"] + + +@pytest.mark.asyncio +async def test_compile_component_orders_pre_and_post_by_plugin() -> None: + events: list[str] = [] + root = RootComponent.create() + + class FirstPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("first:pre") + yield + events.append("first:post") + + class SecondPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("second:pre") + yield + events.append("second:post") + + hooks = CompilerHooks(plugins=(FirstPlugin(), SecondPlugin())) + + compiled_root = await hooks.compile_component(root) + + assert compiled_root is root + assert events == ["first:pre", "second:pre", "second:post", "first:post"] + + +@pytest.mark.asyncio +async def test_compile_component_skips_inherited_protocol_hook( + monkeypatch: pytest.MonkeyPatch, +) -> None: + events: list[str] = [] + root = RootComponent.create() + + async def fail_compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + del self, comp, kwargs + await asyncio.sleep(0) + msg = "Inherited protocol compile_component hook should be skipped." + raise AssertionError(msg) + if False: # pragma: no cover + yield None + + monkeypatch.setattr( + CompilerPlugin, + "compile_component", + fail_compile_component, + ) + + class ProtocolOnlyPlugin(CompilerPlugin): + pass + + class RealPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + events.append("real:pre") + yield + events.append("real:post") + + hooks = CompilerHooks(plugins=(ProtocolOnlyPlugin(), RealPlugin())) + + compiled_root = await hooks.compile_component(root) + + assert compiled_root is root + assert events == ["real:pre", "real:post"] + + +@pytest.mark.asyncio +async def test_compile_component_traverses_children_before_prop_components() -> None: + visited: list[str] = [] + root = RootComponent.create( + ChildComponent.create(), + slot=PropComponent.create(), + ) + + class VisitPlugin(StubCompilerPlugin): + async def compile_component( + self, + comp: BaseComponent, + /, + **kwargs: Any, + ) -> AsyncGenerator[None, ComponentAndChildren]: + if isinstance(comp, Component): + visited.append(comp.tag or type(comp).__name__) + yield + + hooks = CompilerHooks(plugins=(VisitPlugin(),)) + await hooks.compile_component(root) + + assert visited == ["RootComponent", "ChildComponent", "PropComponent"] + + +@pytest.mark.asyncio +async def test_context_lifecycle_and_cleanup() -> None: + compile_ctx = CompileContext(pages=[], hooks=CompilerHooks()) + page_ctx = PageContext( + name="page", + route="/ctx", + root_component=Fragment.create(), + ) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + with pytest.raises(RuntimeError, match="must be entered with 'async with'"): + compile_ctx.ensure_context_attached() + + async with compile_ctx: + assert CompileContext.get() is compile_ctx + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + async with page_ctx: + assert CompileContext.get() is compile_ctx + assert PageContext.get() is page_ctx + page_ctx.ensure_context_attached() + with pytest.raises(RuntimeError, match="No active PageContext"): + PageContext.get() + assert CompileContext.get() is compile_ctx + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + with pytest.raises(ValueError, match="boom"): + async with compile_ctx: + msg = "boom" + raise ValueError(msg) + + with pytest.raises(RuntimeError, match="No active CompileContext"): + CompileContext.get() + + +def test_page_context_default_factories_are_isolated() -> None: + page_ctx_a = PageContext( + name="a", + route="/a", + root_component=Fragment.create(), + ) + page_ctx_b = PageContext( + name="b", + route="/b", + root_component=Fragment.create(), + ) + + page_ctx_a.imports.append({"lib-a": [ImportVar(tag="ThingA")]}) + page_ctx_a.module_code["const a = 1;"] = None + page_ctx_a.hooks["hookA"] = None + page_ctx_a.dynamic_imports.add("dynamic-a") + page_ctx_a.refs["refA"] = None + page_ctx_a.app_wrap_components[1, "WrapA"] = Fragment.create() + + assert page_ctx_b.imports == [] + assert page_ctx_b.module_code == {} + assert page_ctx_b.hooks == {} + assert page_ctx_b.dynamic_imports == set() + assert page_ctx_b.refs == {} + assert page_ctx_b.app_wrap_components == {} + + +def test_page_context_helpers_preserve_accumulated_values() -> None: + page_ctx = PageContext( + name="page", + route="/page", + root_component=Fragment.create(), + ) + page_ctx.imports.extend([ + {"lib-a": [ImportVar(tag="ThingA")]}, + {"lib-a": [ImportVar(tag="ThingB")], "lib-b": [ImportVar(tag="ThingC")]}, + ]) + page_ctx.module_code["const first = 1;"] = None + page_ctx.module_code["const second = 2;"] = None + + assert page_ctx.merged_imports() == merge_imports(*page_ctx.imports) + assert page_ctx.merged_imports(collapse=True) == collapse_imports( + merge_imports(*page_ctx.imports) + ) + assert list(page_ctx.custom_code_dict()) == [ + "const first = 1;", + "const second = 2;", + ] + + +def test_base_context_subclasses_initialize_distinct_context_vars() -> None: + class DynamicContext(BaseContext): + pass + + class AnotherDynamicContext(BaseContext): + pass + + assert DynamicContext.__context_var__ is not AnotherDynamicContext.__context_var__ + + +@pytest.mark.asyncio +async def test_apply_style_plugin_matches_legacy_recursive_behavior() -> None: + legacy_component = create_component_tree() + plugin_component = create_component_tree() + style = page_style() + + legacy_component._add_style_recursive(style) + page_ctx = await collect_page_context( + plugin_component, + plugins=(ApplyStylePlugin(style=style),), + ) + compiled_root = cast(RootComponent, page_ctx.root_component) + assert compiled_root.slot is not None + assert legacy_component.slot is not None + + assert compiled_root.render() == legacy_component.render() + assert compiled_root.slot.render() == legacy_component.slot.render() + + +@pytest.mark.asyncio +async def test_consolidate_imports_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateImportsPlugin(),), + ) + + assert page_ctx.merged_imports(collapse=True) == root._get_all_imports( + collapse=True + ) + + +@pytest.mark.asyncio +async def test_consolidate_hooks_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateHooksPlugin(),), + ) + + assert page_ctx.hooks == root._get_all_hooks() + assert "const propHook = usePropHook();" not in page_ctx.hooks + + +@pytest.mark.asyncio +async def test_consolidate_custom_code_plugin_matches_legacy_recursive_collector() -> ( + None +): + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateCustomCodePlugin(),), + ) + + assert page_ctx.custom_code_dict() == root._get_all_custom_code() + assert list(page_ctx.custom_code_dict()) == list(root._get_all_custom_code()) + + +@pytest.mark.asyncio +async def test_consolidate_dynamic_imports_plugin_matches_legacy_recursive_collector() -> ( + None +): + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateDynamicImportsPlugin(),), + ) + + assert page_ctx.dynamic_imports == root._get_all_dynamic_imports() + + +@pytest.mark.asyncio +async def test_consolidate_refs_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateRefsPlugin(),), + ) + + assert page_ctx.refs == root._get_all_refs() + + +@pytest.mark.asyncio +async def test_consolidate_app_wrap_plugin_matches_legacy_recursive_collector() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=(ConsolidateAppWrapPlugin(),), + ) + + assert ( + page_ctx.app_wrap_components.keys() + == root._get_all_app_wrap_components().keys() + ) + assert (15, "PropWrap") not in page_ctx.app_wrap_components + assert (20, "NestedWrap") in page_ctx.app_wrap_components + + +@pytest.mark.asyncio +async def test_default_page_plugin_evaluates_page_like_legacy_compile_path() -> None: + page = FakePage( + route="/demo", + component=create_component_tree, + title="Demo", + description="Demo page", + image="demo.png", + meta=({"name": "robots", "content": "index,follow"},), + ) + hooks = CompilerHooks(plugins=(DefaultPagePlugin(),)) + + page_ctx = await hooks.eval_page(page.component, page=page) + + assert page_ctx is not None + assert page_ctx.route == "/demo" + assert page_ctx.name == "create_component_tree" + assert any(child.tag == "title" for child in page_ctx.root_component.children) + assert any(child.tag == "meta" for child in page_ctx.root_component.children) + + +@pytest.mark.asyncio +async def test_default_plugin_pipeline_matches_legacy_page_context_data() -> None: + root = create_component_tree() + + page_ctx = await collect_page_context( + root, + plugins=default_page_plugins(), + ) + + assert page_ctx.merged_imports(collapse=True) == root._get_all_imports( + collapse=True + ) + assert page_ctx.hooks == root._get_all_hooks() + assert page_ctx.custom_code_dict() == root._get_all_custom_code() + assert page_ctx.dynamic_imports == root._get_all_dynamic_imports() + assert page_ctx.refs == root._get_all_refs() + assert ( + page_ctx.app_wrap_components.keys() + == root._get_all_app_wrap_components().keys() + ) + + +@pytest.mark.asyncio +async def test_compile_context_compiles_pages_and_accumulates_page_data() -> None: + page = FakePage( + route="/demo", + component=lambda: RootComponent.create( + ChildComponent.create(id="child-id"), + slot=PropComponent.create(), + ), + ) + compile_ctx = CompileContext( + pages=[page], + hooks=CompilerHooks( + plugins=(EvalPagePlugin(), CollectPageDataPlugin()), + ), + ) + + async with compile_ctx: + compiled_pages = await compile_ctx.compile() + + assert compiled_pages is compile_ctx.compiled_pages + assert list(compiled_pages) == ["/demo"] + + page_ctx = compiled_pages["/demo"] + assert page_ctx.name == "" + assert page_ctx.route == "/demo" + assert page_ctx.imports + assert set(page_ctx.imports[0]) >= {"root-lib", "child-lib", "prop-lib", "react"} + assert page_ctx.module_code == { + "const childCustomCode = 1;": None, + "const propCustomCode = 1;": None, + } + assert page_ctx.dynamic_imports == {"dynamic(() => import('prop-lib'))"} + assert any("useChildHook" in hook for hook in page_ctx.hooks) + assert any("useRef" in hook for hook in page_ctx.hooks) + assert page_ctx.refs == {format_utils.format_ref("child-id"): None} + assert (10, "Wrap") in page_ctx.app_wrap_components + assert (15, "PropWrap") in page_ctx.app_wrap_components + + +@pytest.mark.asyncio +async def test_compile_context_rejects_duplicate_routes() -> None: + pages = [ + FakePage(route="/duplicate", component=lambda: Fragment.create()), + FakePage(route="/duplicate", component=lambda: Fragment.create()), + ] + compile_ctx = CompileContext( + pages=pages, + hooks=CompilerHooks(plugins=(EvalPagePlugin(),)), + ) + + async with compile_ctx: + with pytest.raises(RuntimeError, match="Duplicate compiled page route"): + await compile_ctx.compile() + + +@pytest.mark.asyncio +async def test_compile_context_requires_attached_context() -> None: + compile_ctx = CompileContext( + pages=[], + hooks=CompilerHooks(), + ) + + with pytest.raises(RuntimeError, match="must be entered with 'async with'"): + await compile_ctx.compile() diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index d5fee452c5a..5f5ba11019a 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -306,9 +306,16 @@ async def test_oplock_contention_queue( state_manager_2._oplock_enabled = True modify_started = asyncio.Event() - modify_2_started = asyncio.Event() + contenders_started = asyncio.Event() modify_1_continue = asyncio.Event() modify_2_continue = asyncio.Event() + contender_started_count = 0 + + def mark_contender_started() -> None: + nonlocal contender_started_count + contender_started_count += 1 + if contender_started_count == 2: + contenders_started.set() async def modify_1(): async with state_manager_redis.modify_state( @@ -321,7 +328,7 @@ async def modify_1(): async def modify_2(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -331,7 +338,7 @@ async def modify_2(): async def modify_3(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -343,7 +350,7 @@ async def modify_3(): task_2 = asyncio.create_task(modify_2()) task_3 = asyncio.create_task(modify_3()) - await modify_2_started.wait() + await contenders_started.wait() # Let modify 1 complete modify_1_continue.set() @@ -407,9 +414,16 @@ async def test_oplock_contention_no_lease( state_manager_3._oplock_enabled = True modify_started = asyncio.Event() - modify_2_started = asyncio.Event() + contenders_started = asyncio.Event() modify_1_continue = asyncio.Event() modify_2_continue = asyncio.Event() + contender_started_count = 0 + + def mark_contender_started() -> None: + nonlocal contender_started_count + contender_started_count += 1 + if contender_started_count == 2: + contenders_started.set() async def modify_1(): async with state_manager_redis.modify_state( @@ -422,7 +436,7 @@ async def modify_1(): async def modify_2(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: @@ -432,7 +446,7 @@ async def modify_2(): async def modify_3(): await modify_started.wait() - modify_2_started.set() + mark_contender_started() async with state_manager_3.modify_state( _substate_key(token, root_state), ) as new_state: @@ -444,7 +458,7 @@ async def modify_3(): task_2 = asyncio.create_task(modify_2()) task_3 = asyncio.create_task(modify_3()) - await modify_2_started.wait() + await contenders_started.wait() # Let modify 1 complete modify_1_continue.set()