diff --git a/packages/reflex-base/src/reflex_base/.templates/web/utils/state.js b/packages/reflex-base/src/reflex_base/.templates/web/utils/state.js index fa51f43962f..fd2b5e5741b 100644 --- a/packages/reflex-base/src/reflex_base/.templates/web/utils/state.js +++ b/packages/reflex-base/src/reflex_base/.templates/web/utils/state.js @@ -40,8 +40,6 @@ const cookies = new Cookies(); // Dictionary holding component references. export const refs = {}; -// Flag ensures that only one event is processing on the backend concurrently. -let event_processing = false; // Array holding pending events to be processed. const event_queue = []; @@ -203,14 +201,12 @@ function urlFrom(string) { * @param socket The socket object to send the event on. * @param navigate The navigate function from useNavigate * @param params The params object from useParams - * - * @returns True if the event was sent, false if it was handled locally. */ export const applyEvent = async (event, socket, navigate, params) => { // Handle special events if (event.name == "_redirect") { if ((event.payload.path ?? undefined) === undefined) { - return false; + return; } if (event.payload.external) { window.open( @@ -218,7 +214,7 @@ export const applyEvent = async (event, socket, navigate, params) => { "_blank", "noopener" + (event.payload.popup ? 
",popup" : ""), ); - return false; + return; } const url = urlFrom(event.payload.path); let pathname = event.payload.path; @@ -226,7 +222,7 @@ export const applyEvent = async (event, socket, navigate, params) => { if (url.host !== window.location.host) { // External URL window.location.assign(event.payload.path); - return false; + return; } else { pathname = url.pathname + url.search + url.hash; } @@ -236,37 +232,37 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { navigate(pathname); } - return false; + return; } if (event.name == "_remove_cookie") { cookies.remove(event.payload.key, { ...event.payload.options }); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_clear_local_storage") { localStorage.clear(); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_remove_local_storage") { localStorage.removeItem(event.payload.key); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_clear_session_storage") { sessionStorage.clear(); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_remove_session_storage") { sessionStorage.removeItem(event.payload.key); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_download") { @@ -285,7 +281,7 @@ export const applyEvent = async (event, socket, navigate, params) => { a.download = event.payload.filename; a.click(); a.remove(); - return false; + return; } if (event.name == "_set_focus") { @@ -299,7 +295,7 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { current.focus(); } - return false; + return; } if (event.name == "_blur_focus") { @@ -313,7 +309,7 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { 
current.blur(); } - return false; + return; } if (event.name == "_set_value") { @@ -322,7 +318,7 @@ export const applyEvent = async (event, socket, navigate, params) => { if (ref.current) { ref.current.value = event.payload.value; } - return false; + return; } if ( @@ -348,7 +344,7 @@ export const applyEvent = async (event, socket, navigate, params) => { window.onerror(e.message, null, null, null, e); } } - return false; + return; } if (event.name == "_call_script" || event.name == "_call_function") { @@ -375,11 +371,10 @@ export const applyEvent = async (event, socket, navigate, params) => { window.onerror(e.message, null, null, null, e); } } - return false; + return; } // Update token and router data (if missing). - event.token = getToken(); if ( event.router_data === undefined || Object.keys(event.router_data).length === 0 @@ -387,24 +382,24 @@ export const applyEvent = async (event, socket, navigate, params) => { // Since we don't have router directly, we need to get info from our hooks event.router_data = { pathname: window.location.pathname, - query: { - ...Object.fromEntries(new URLSearchParams(window.location.search)), - ...params(), - }, asPath: window.location.pathname + window.location.search + window.location.hash, }; + const query = { + ...Object.fromEntries(new URLSearchParams(window.location.search)), + ...params.current, + }; + if (query && Object.keys(query).length > 0) { + event.router_data.query = query; + } } // Send the event to the server. if (socket) { socket.emit("event", event); - return true; } - - return false; }; /** @@ -413,11 +408,8 @@ export const applyEvent = async (event, socket, navigate, params) => { * @param socket The socket object to send the response event(s) on. * @param navigate The navigate function from React Router * @param params The params object from React Router - * - * @returns Whether the event was sent. 
*/ export const applyRestEvent = async (event, socket, navigate, params) => { - let eventSent = false; if (event.handler === "uploadFiles") { // Start upload, but do not wait for it, which would block other events. uploadFiles( @@ -431,9 +423,7 @@ export const applyRestEvent = async (event, socket, navigate, params) => { getBackendURL, getToken, ); - return false; } - return eventSent; }; /** @@ -487,28 +477,21 @@ export const processEvent = async (socket, navigate, params) => { } // Only proceed if we're not already processing an event. - if (event_queue.length === 0 || event_processing) { + if (event_queue.length === 0) { return; } - // Set processing to true to block other events from being processed. - event_processing = true; - // Apply the next event in the queue. const event = event_queue.shift(); - let eventSent = false; // Process events with handlers via REST and all others via websockets. if (event.handler) { - eventSent = await applyRestEvent(event, socket, navigate, params); + await applyRestEvent(event, socket, navigate, params); } else { - eventSent = await applyEvent(event, socket, navigate, params); + await applyEvent(event, socket, navigate, params); } - // If no event was sent, set processing to false. - if (!eventSent) { - event_processing = false; - // recursively call processEvent to drain the queue, since there is - // no state update to trigger the useEffect event loop. + // Process any remaining events. + if (event_queue.length > 0) { await processEvent(socket, navigate, params); } }; @@ -621,17 +604,11 @@ export const connect = async ( window.addEventListener("unload", disconnectTrigger); if (socket.current.rehydrate) { socket.current.rehydrate = false; - queueEvents( - initialEvents(), - socket, - true, - navigate, - () => params.current, - ); + queueEvents(initialEvents(), socket, true, navigate, params); } // Drain any initial events from the queue. 
- while (event_queue.length > 0 && !event_processing) { - await processEvent(socket.current, navigate, () => params.current); + while (event_queue.length > 0) { + await processEvent(socket.current, navigate, params); } }); @@ -650,12 +627,10 @@ export const connect = async ( }, 200 * n_connect_errors); // Incremental backoff }); - // When the socket disconnects reset the event_processing flag socket.current.on("disconnect", (reason, details) => { socket.current.wait_connect = false; const try_reconnect = reason !== "io server disconnect" && reason !== "io client disconnect"; - event_processing = false; window.removeEventListener("unload", disconnectTrigger); window.removeEventListener("beforeunload", disconnectTrigger); window.removeEventListener("pagehide", pagehideHandler); @@ -667,30 +642,24 @@ export const connect = async ( // On each received message, queue the updates and events. socket.current.on("event", async (update) => { - for (const substate in update.delta) { - dispatch[substate](update.delta[substate]); - // handle events waiting for `is_hydrated` - if ( - substate === state_name && - update.delta[substate]?.is_hydrated_rx_state_ - ) { - queueEvents(on_hydrated_queue, socket, false, navigate, params); - on_hydrated_queue.length = 0; + if (update.delta && Object.keys(update.delta).length > 0) { + for (const substate in update.delta) { + dispatch[substate](update.delta[substate]); + // handle events waiting for `is_hydrated` + if ( + substate === state_name && + update.delta[substate]?.is_hydrated_rx_state_ + ) { + queueEvents(on_hydrated_queue, socket, false, navigate, params); + on_hydrated_queue.length = 0; + } } + applyClientStorageDelta(client_storage, update.delta); } - applyClientStorageDelta(client_storage, update.delta); - if (update.final !== null) { - event_processing = !update.final; - } - if (update.events) { + if (update.events && update.events.length > 0) { queueEvents(update.events, socket, false, navigate, params); } }); - 
socket.current.on("reload", async (event) => { - event_processing = false; - on_hydrated_queue.push(event); - queueEvents(initialEvents(), socket, true, navigate, params); - }); socket.current.on("new_token", async (new_token) => { token = new_token; window.sessionStorage.setItem(TOKEN_KEY, new_token); @@ -713,7 +682,17 @@ export const ReflexEvent = ( event_actions = {}, handler = null, ) => { - return { name, payload, handler, event_actions }; + const e = { name }; + if (payload && Object.keys(payload).length > 0) { + e.payload = payload; + } + if (event_actions && Object.keys(event_actions).length > 0) { + e.event_actions = event_actions; + } + if (handler !== null) { + e.handler = handler; + } + return e; }; /** @@ -919,7 +898,7 @@ export const useEventLoop = ( setConnectErrors, client_storage, navigate, - () => params.current, + params, ); } }, [ @@ -947,7 +926,7 @@ export const useEventLoop = ( } return applyEventActions( - () => queueEvents(_events, socket, false, navigate, () => params.current), + () => queueEvents(_events, socket, false, navigate, params), event_actions, args, _events.map((e) => e.name).join("+++"), @@ -958,13 +937,7 @@ export const useEventLoop = ( const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode useEffect(() => { if (!sentHydrate.current) { - queueEvents( - initial_events(), - socket, - true, - navigate, - () => params.current, - ); + queueEvents(initial_events(), socket, true, navigate, params); sentHydrate.current = true; } }, []); @@ -1028,9 +1001,9 @@ export const useEventLoop = ( } (async () => { // Process all outstanding events. 
- while (event_queue.length > 0 && !event_processing) { + while (event_queue.length > 0) { await ensureSocketConnected(); - await processEvent(socket.current, navigate, () => params.current); + await processEvent(socket.current, navigate, params); } })(); }); diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index b83f213196a..d7903c5c9e2 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -36,7 +36,7 @@ PageNames, ) from reflex_base.constants.compiler import SpecialAttributes -from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER, FRONTEND_EVENT_STATE +from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.event import ( EventCallback, EventChain, @@ -888,7 +888,7 @@ def _post_init(self, *args, **kwargs): # Get the passed type and the var type. passed_type = kwargs[key]._var_type - expected_type = types.get_args( + expected_type = typing.get_args( types.get_field_type(type(self), key) )[0] except TypeError: @@ -1523,10 +1523,7 @@ def _event_trigger_values_use_state(self) -> bool: if isinstance(event, EventCallback): continue if isinstance(event, EventSpec): - if ( - event.handler.state_full_name - and event.handler.state_full_name != FRONTEND_EVENT_STATE - ): + if event.handler.state is not None: return True else: if event._var_state: @@ -2389,9 +2386,6 @@ class StatefulComponent(BaseComponent): was created with. """ - # A lookup table to caching memoized component instances. - tag_to_stateful_component: ClassVar[dict[str, StatefulComponent]] = {} - # Reference to the original component that was memoized into this component. 
component: Component = field( default_factory=Component, is_javascript_property=False @@ -2425,6 +2419,8 @@ def create(cls, component: Component) -> StatefulComponent | None: """ from reflex_components_core.core.foreach import Foreach + from reflex_base.registry import RegistrationContext + if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. return None @@ -2469,11 +2465,12 @@ def create(cls, component: Component) -> StatefulComponent | None: return None # Look up the tag in the cache - stateful_component = cls.tag_to_stateful_component.get(tag_name) + ctx = RegistrationContext.get() + stateful_component = ctx.tag_to_stateful_component.get(tag_name) if stateful_component is None: memo_trigger_hooks = cls._fix_event_triggers(component) # Set the stateful component in the cache for the given tag. - stateful_component = cls.tag_to_stateful_component.setdefault( + stateful_component = ctx.tag_to_stateful_component.setdefault( tag_name, cls( children=component.children, diff --git a/packages/reflex-base/src/reflex_base/constants/state.py b/packages/reflex-base/src/reflex_base/constants/state.py index 3f6ebec2f17..8742f76e185 100644 --- a/packages/reflex-base/src/reflex_base/constants/state.py +++ b/packages/reflex-base/src/reflex_base/constants/state.py @@ -11,9 +11,6 @@ class StateManagerMode(str, Enum): REDIS = "redis" -# Used for things like console_log, etc. 
-FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state" - FIELD_MARKER = "_rx_state_" MEMO_MARKER = "_rx_memo_" CAMEL_CASE_MEMO_MARKER = "RxMemo" diff --git a/packages/reflex-base/src/reflex_base/context/__init__.py b/packages/reflex-base/src/reflex_base/context/__init__.py new file mode 100644 index 00000000000..8c279f0bb08 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/context/__init__.py @@ -0,0 +1 @@ +"""Internal ContextVar and registration helpers for reflex.""" diff --git a/packages/reflex-base/src/reflex_base/context/base.py b/packages/reflex-base/src/reflex_base/context/base.py new file mode 100644 index 00000000000..7bb28d4864c --- /dev/null +++ b/packages/reflex-base/src/reflex_base/context/base.py @@ -0,0 +1,81 @@ +"""Shared contextvars wrapper for contextual globals.""" + +from __future__ import annotations + +import dataclasses +from contextvars import ContextVar, Token +from typing import ClassVar + +from typing_extensions import Self + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BaseContext: + """Base context class that acts as an async context manager to set the context var.""" + + _context_var: ClassVar[ContextVar[Self]] + _attached_context_token: ClassVar[dict[Self, Token[Self]]] + + @classmethod + def __init_subclass__(cls, **kwargs): + """Initialize the context variable for the subclass.""" + super(BaseContext, cls).__init_subclass__(**kwargs) + cls._context_var = ContextVar(cls.__name__) + cls._attached_context_token = {} + + @classmethod + def get(cls) -> Self: + """Get the context from the context variable. + + Returns: + The context instance. + """ + return cls._context_var.get() + + @classmethod + def set(cls, context: Self) -> Token[Self]: + """Set the context in the context variable. + + Args: + context: The context instance to set. + + Returns: + The token for resetting the context variable. 
+ """ + return cls._context_var.set(context) + + @classmethod + def reset(cls, token: Token[Self]) -> None: + """Reset the context variable to a previous state. + + Args: + token: The token to reset the context variable to. + """ + cls._context_var.reset(token) + + def __enter__(self) -> Self: + """Enter the context. + + Returns: + This context instance. + """ + if self._attached_context_token.get(self) is not None: + msg = "Context is already attached, cannot enter context manager." + raise RuntimeError(msg) + self._attached_context_token[self] = self._context_var.set(self) + return self + + def __exit__(self, *exc_info): + """Exit the context.""" + if (token := self._attached_context_token.pop(self)) is not None: + self._context_var.reset(token) + + def ensure_context_attached(self): + """Ensure that the context is attached to the current context variable. + + Raises: + RuntimeError: If the context is not attached. + """ + if self._attached_context_token.get(self) is None: + msg = f"{type(self).__name__} must be entered before calling this method." 
+ raise RuntimeError(msg) diff --git a/packages/reflex-base/src/reflex_base/event.py b/packages/reflex-base/src/reflex_base/event/__init__.py similarity index 93% rename from packages/reflex-base/src/reflex_base/event.py rename to packages/reflex-base/src/reflex_base/event/__init__.py index 45e068ac951..ca0347745f9 100644 --- a/packages/reflex-base/src/reflex_base/event.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -28,7 +28,6 @@ from reflex_base import constants from reflex_base.components.field import BaseField from reflex_base.constants.compiler import CompileVars, Hooks, Imports -from reflex_base.constants.state import FRONTEND_EVENT_STATE from reflex_base.utils import format from reflex_base.utils.decorator import once from reflex_base.utils.exceptions import ( @@ -56,6 +55,9 @@ ) from reflex_base.vars.object import ObjectVar +if TYPE_CHECKING: + from reflex.state import BaseState + @dataclasses.dataclass( init=True, @@ -65,14 +67,11 @@ class Event: """An event that describes any state change in the app. Attributes: - token: The token to specify the client that the event is for. name: The event name. router_data: The routing data where event occurred. payload: The event payload. """ - token: str - name: str router_data: dict[str, Any] = dataclasses.field(default_factory=dict) @@ -80,14 +79,71 @@ class Event: payload: dict[str, Any] = dataclasses.field(default_factory=dict) @property - def substate_token(self) -> str: - """Get the substate token for the event. + def state_cls(self) -> "type[BaseState]": + """The state class for the event.""" + from reflex_base.registry import RegistrationContext + + substate_name = self.name.rpartition(".")[0] + return RegistrationContext.get().base_states[substate_name] + + @classmethod + def from_event_type( + cls, + events: "IndividualEventType | list[IndividualEventType] | None", + *, + router_data: dict[str, Any] | None = None, + ) -> "list[Event]": + """Create a list of Events from event-like objects. 
+ + Args: + events: The event-like objects to create Events from. + router_data: The routing data for the events. Returns: - The substate token. + A list of Events created from the event-like objects. """ - substate = self.name.rpartition(".")[0] - return f"{self.token}_{substate}" + # If the event handler returns nothing, return an empty list. + if events is None: + return [] + + # If the handler returns a single event, wrap it in a list. + if not isinstance(events, list): + events = [events] + + # Fix the events created by the handler. + out = [] + for e in events: + if callable(e) and getattr(e, "__name__", "") == "": + # A lambda was returned, assume the user wants to call it with no args. + e = e() + if isinstance(e, Event): + # If the event is already an event, append it to the list. + if router_data is not None and e.router_data != router_data: + out.append( + dataclasses.replace(e, router_data=e.router_data | router_data) + ) + else: + out.append(e) + continue + # Otherwise, create an event from the event spec. + if isinstance(e, EventHandler): + e = e() + if not isinstance(e, EventSpec): + msg = f"Unexpected event type, {type(e)}." + raise ValueError(msg) + name = format.format_event_handler(e.handler) + payload = {k._js_expr: v._decode() for k, v in e.args} + + # Create an event and append it to the list. + out.append( + Event( + name=name, + payload=payload, + router_data=router_data or {}, + ) + ) + + return out _EVENT_FIELDS: set[str] = {f.name for f in dataclasses.fields(Event)} @@ -108,8 +164,8 @@ def _handler_name(handler: "EventHandler") -> str: Returns: The fully qualified handler name. """ - if handler.state_full_name: - return f"{handler.state_full_name}.{handler.fn.__name__}" + if handler.state is not None: + return f"{handler.state.get_full_name()}.{handler.fn.__name__}" return handler.fn.__qualname__ @@ -278,12 +334,21 @@ class EventHandler(EventActionsMixin): Attributes: fn: The function to call in response to the event. 
- state_full_name: The full name of the state class this event handler is attached to. Empty string means this event handler is a server side event. + state: The state this EventHandler is directly attached to, if any. """ fn: Any = dataclasses.field(default=None) - state_full_name: str = dataclasses.field(default="") + state: "type[BaseState] | None" = dataclasses.field(default=None, repr=False) + + @property + def state_full_name(self) -> str: + """Get the full name of the state class this event handler is attached to. + + Returns: + The full name of the state class this event handler is attached to. + """ + return self.state.get_full_name() if self.state else "" def __hash__(self): """Get the hash of the event handler. @@ -291,7 +356,7 @@ def __hash__(self): Returns: The hash of the event handler. """ - return hash((tuple(self.event_actions.items()), self.fn, self.state_full_name)) + return hash((tuple(self.event_actions.items()), self.fn, self.state)) def get_parameters(self) -> Mapping[str, inspect.Parameter]: """Get the parameters of the function. @@ -354,7 +419,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": from reflex_base.utils.exceptions import EventHandlerTypeError # Get the function args. - fn_args = list(self._parameters)[1:] + if self.state is not None: + # Skip the `self` arg for state-bound event handlers. + fn_args = list(self._parameters)[1:] + else: + fn_args = list(self._parameters) if not isinstance( repeated_arg := next( @@ -476,8 +545,10 @@ def add_args(self, *args: Var) -> "EventSpec": """ from reflex_base.utils.exceptions import EventHandlerTypeError + n_self_args = 1 if self.handler.state is not None else 0 + # Get the remaining unfilled function args. - fn_args = list(self.handler._parameters)[1 + len(self.args) :] + fn_args = list(self.handler._parameters)[n_self_args + len(self.args) :] fn_args = (Var(_js_expr=arg) for arg in fn_args) # Construct the payload. 
@@ -1134,7 +1205,7 @@ def fn(): fn.__qualname__ = name fn.__signature__ = sig # pyright: ignore [reportFunctionMemberAccess] return EventSpec( - handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), + handler=EventHandler(fn=fn), args=tuple( ( Var(_js_expr=k), @@ -1183,7 +1254,7 @@ def redirect( """ return server_side( "_redirect", - get_fn_signature(redirect), + inspect.signature(redirect), path=path, external=is_external, popup=popup, @@ -1247,7 +1318,7 @@ def set_focus(ref: str) -> EventSpec: """ return server_side( "_set_focus", - get_fn_signature(set_focus), + inspect.signature(set_focus), ref=LiteralVar.create(format.format_ref(ref)), ) @@ -1263,7 +1334,7 @@ def blur_focus(ref: str) -> EventSpec: """ return server_side( "_blur_focus", - get_fn_signature(blur_focus), + inspect.signature(blur_focus), ref=LiteralVar.create(format.format_ref(ref)), ) @@ -1301,7 +1372,7 @@ def set_value(ref: str, value: Any) -> EventSpec: """ return server_side( "_set_value", - get_fn_signature(set_value), + inspect.signature(set_value), ref=LiteralVar.create(format.format_ref(ref)), value=value, ) @@ -1321,7 +1392,7 @@ def remove_cookie(key: str, options: dict[str, Any] | None = None) -> EventSpec: options["path"] = options.get("path", "/") return server_side( "_remove_cookie", - get_fn_signature(remove_cookie), + inspect.signature(remove_cookie), key=key, options=options, ) @@ -1335,7 +1406,7 @@ def clear_local_storage() -> EventSpec: """ return server_side( "_clear_local_storage", - get_fn_signature(clear_local_storage), + inspect.signature(clear_local_storage), ) @@ -1350,7 +1421,7 @@ def remove_local_storage(key: str) -> EventSpec: """ return server_side( "_remove_local_storage", - get_fn_signature(remove_local_storage), + inspect.signature(remove_local_storage), key=key, ) @@ -1363,7 +1434,7 @@ def clear_session_storage() -> EventSpec: """ return server_side( "_clear_session_storage", - get_fn_signature(clear_session_storage), + 
inspect.signature(clear_session_storage), ) @@ -1378,7 +1449,7 @@ def remove_session_storage(key: str) -> EventSpec: """ return server_side( "_remove_session_storage", - get_fn_signature(remove_session_storage), + inspect.signature(remove_session_storage), key=key, ) @@ -1475,7 +1546,7 @@ def download( return server_side( "_download", - get_fn_signature(download), + inspect.signature(download), url=url, filename=filename, ) @@ -1516,7 +1587,7 @@ def call_script( return server_side( "_call_script", - get_fn_signature(call_script), + inspect.signature(call_script), javascript_code=javascript_code, **callback_kwargs, ) @@ -1552,7 +1623,7 @@ def call_function( return server_side( "_call_function", - get_fn_signature(call_function), + inspect.signature(call_function), function=javascript_code, **callback_kwargs, ) @@ -1735,13 +1806,14 @@ def call_event_handler( if isinstance(event_callback, EventSpec): parameters = event_callback.handler._parameters + n_self_args = 1 if event_callback.handler.state is not None else 0 check_fn_match_arg_spec( event_callback.handler.fn, parameters, event_spec_args, key, - bool(event_callback.handler.state_full_name) + len(event_callback.args), + n_self_args + len(event_callback.args), event_callback.handler.fn.__qualname__, ) @@ -1757,9 +1829,7 @@ def call_event_handler( _check_event_args_subclass_of_callback( [ arg - for arg in event_callback_spec_args[ - bool(event_callback.handler.state_full_name) : - ] + for arg in event_callback_spec_args[n_self_args:] if arg not in argument_names ], event_spec_return_types, @@ -1771,6 +1841,8 @@ def call_event_handler( # Handle partial application of EventSpec args return event_callback.add_args(*event_spec_args) + n_self_args = 1 if event_callback.state is not None else 0 + parameters = event_callback._parameters check_fn_match_arg_spec( @@ -1778,7 +1850,7 @@ def call_event_handler( parameters, event_spec_args, key, - bool(event_callback.state_full_name), + n_self_args, 
event_callback.fn.__qualname__, ) @@ -1791,7 +1863,7 @@ def call_event_handler( type_hints_of_provided_callback = {} _check_event_args_subclass_of_callback( - event_callback_spec_args[1:], + event_callback_spec_args[n_self_args:], event_spec_return_types, type_hints_of_provided_callback, event_callback.fn.__qualname__, @@ -2030,14 +2102,16 @@ def get_handler_args( def fix_events( events: list[EventSpec | EventHandler] | None, - token: str, + token: str | None = None, router_data: dict[str, Any] | None = None, ) -> list[Event]: """Fix a list of events returned by an event handler. + Deprecated: use Event.from_event_type instead. + Args: events: The events to fix. - token: The user token. + token: Deprecated, ignored. Kept for backward compatibility. router_data: The optional router data to set in the event. Returns: @@ -2046,6 +2120,14 @@ def fix_events( Raises: ValueError: If the event type is not what was expected. """ + from reflex_base.utils.console import deprecate + + deprecate( + feature_name="rx.event.fix_events()", + reason="Use Event.from_event_type() instead", + deprecation_version="0.9.0", + removal_version="1.0", + ) # If the event handler returns nothing, return an empty list. if events is None: return [] @@ -2084,7 +2166,6 @@ def fix_events( # Create an event and append it to the list. out.append( Event( - token=token, name=name, payload=payload, router_data=event_router_data, @@ -2094,22 +2175,6 @@ def fix_events( return out -def get_fn_signature(fn: Callable) -> inspect.Signature: - """Get the signature of a function. - - Args: - fn: The function. - - Returns: - The signature of the function. - """ - signature = inspect.signature(fn) - new_param = inspect.Parameter( - FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any - ) - return signature.replace(parameters=(new_param, *signature.parameters.values())) - - # These chains can be used for their side effects when no other events are desired. 
stop_propagation = noop().stop_propagation prevent_default = noop().prevent_default @@ -2713,7 +2778,6 @@ def wrapper( parse_args_spec = staticmethod(parse_args_spec) args_specs_from_fields = staticmethod(args_specs_from_fields) unwrap_var_annotation = staticmethod(unwrap_var_annotation) - get_fn_signature = staticmethod(get_fn_signature) # Event Spec Functions passthrough_event_spec = staticmethod(passthrough_event_spec) @@ -2750,7 +2814,26 @@ def wrapper( run_script = staticmethod(run_script) __file__ = __file__ + @property + def BaseState(self) -> "type[BaseState]": # noqa: N802 + """Get the BaseState class. + + A reference to BaseState is needed for doc generation when resolving + type hints, so add it to the namespace late to avoid circular import + issues. + + Returns: + The BaseState class. + """ + from reflex.state import BaseState + + return BaseState + event = EventNamespace event.event = event # pyright: ignore[reportAttributeAccessIssue] +_this = sys.modules[__name__] +event.__path__ = _this.__path__ # pyright: ignore[reportAttributeAccessIssue] +event.__spec__ = _this.__spec__ # pyright: ignore[reportAttributeAccessIssue] +event.__package__ = _this.__package__ # pyright: ignore[reportAttributeAccessIssue] sys.modules[__name__] = event # pyright: ignore[reportArgumentType] diff --git a/packages/reflex-base/src/reflex_base/event/context.py b/packages/reflex-base/src/reflex_base/event/context.py new file mode 100644 index 00000000000..df3f0200e62 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/context.py @@ -0,0 +1,146 @@ +"""The context and associated metadata for handling an event.""" + +from __future__ import annotations + +import dataclasses +import functools +import uuid +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, Protocol + +from reflex_base.context.base import BaseContext +from reflex_base.utils.format import to_snake_case + +if TYPE_CHECKING: + from reflex.istate.manager import StateManager 
+ from reflex_base.event import Event + + +@functools.lru_cache +def get_name(cls: type | Callable) -> str: + """Get the name of the state/func. + + Returns: + The name of the state/func. + """ + module = cls.__module__.replace(".", "___") + qualname = getattr(cls, "__qualname__", cls.__name__).replace(".", "___") + return to_snake_case(f"{module}___{qualname}") + + +class EnqueueProtocol(Protocol): + """Protocol for the enqueue function in the event context.""" + + async def __call__(self, token: str, *events: Event) -> Any: + """Enqueue an event handler to be executed. + + Args: + token: The client token associated with the event. + events: The events to enqueue. + """ + ... + + +class EmitEventProtocol(Protocol): + """Protocol for the emit_event function in the event context.""" + + async def __call__(self, token: str, *events: Event) -> Any: + """Emit an event to be processed immediately. + + Args: + token: The client token associated with the event. + events: The events to emit. + """ + ... + + +class EmitDeltaProtocol(Protocol): + """Protocol for the emit_delta function in the event context.""" + + async def __call__( + self, + token: str, + delta: Mapping[str, Mapping[str, Any]], + ) -> Any: + """Emit a delta to the frontend. + + Args: + token: The client token to emit the delta to. + delta: The deltas to emit, mapping client tokens to variable updates. + """ + ... + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) +class EventContext(BaseContext): + """The context for an event.""" + + # Identifies the client session. + token: str + + # Manages persistence of state across events. + state_manager: StateManager = dataclasses.field(repr=False) + + # Function responsible for enqueuing an event handler to be executed. + enqueue_impl: EnqueueProtocol = dataclasses.field(repr=False) + + # Each event is associated with a top-level transaction id. 
+ txid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex[:12]) + # The txid of another EventContext that enqueued this context's event. + parent_txid: str | None = None + + emit_delta_impl: EmitDeltaProtocol | None = dataclasses.field( + default=None, repr=False + ) + emit_event_impl: EmitEventProtocol | None = dataclasses.field( + default=None, repr=False + ) + cached_states: dict[type, Any] = dataclasses.field( + default_factory=dict, init=False, repr=False + ) + + def fork(self, token: str | None = None) -> EventContext: + """Return a new EventContext with the specified fields replaced. + + Args: + token: The client token for the new context. + + Returns: + A new EventContext with the specified fields replaced. + """ + return type(self)( + token=token or self.token, + parent_txid=self.txid, + state_manager=self.state_manager, + enqueue_impl=self.enqueue_impl, + emit_delta_impl=self.emit_delta_impl, + emit_event_impl=self.emit_event_impl, + ) + + async def emit_delta(self, delta: Mapping[str, Mapping[str, Any]]) -> None: + """Emit a delta to the frontend. + + Args: + delta: The deltas to emit, mapping client tokens to variable updates. + """ + if self.emit_delta_impl is not None: + await self.emit_delta_impl(self.token, delta) + + async def emit_event(self, *events: Event) -> None: + """Emit an event to be processed on the frontend. + + If no such handler exists, the event will not be processed. + + Args: + events: The events to emit. + """ + if self.emit_event_impl is not None: + await self.emit_event_impl(self.token, *events) + + async def enqueue(self, *event: Event) -> None: + """Enqueue an event handler to be executed. + + Args: + event: The event to enqueue. 
+ """ + await self.enqueue_impl(self.token, *event) diff --git a/packages/reflex-base/src/reflex_base/event/processor/__init__.py b/packages/reflex-base/src/reflex_base/event/processor/__init__.py new file mode 100644 index 00000000000..f72058483d2 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/__init__.py @@ -0,0 +1,14 @@ +"""Procedures for handling events.""" + +from reflex_base.event.processor.base_state_processor import BaseStateEventProcessor +from reflex_base.event.processor.event_processor import EventProcessor, EventQueueEntry +from reflex_base.event.processor.future import EventFuture +from reflex_base.event.processor.timeout import DrainTimeoutManager + +__all__ = [ + "BaseStateEventProcessor", + "DrainTimeoutManager", + "EventFuture", + "EventProcessor", + "EventQueueEntry", +] diff --git a/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py new file mode 100644 index 00000000000..30e7c9fe315 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py @@ -0,0 +1,401 @@ +"""Functions for processing BaseState-derived event handlers.""" + +from __future__ import annotations + +import dataclasses +import functools +import inspect +import warnings +from collections.abc import Mapping, Sequence +from enum import Enum +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any + +from reflex.istate.data import RouterData +from reflex.istate.manager.token import BaseStateToken +from reflex.istate.proxy import StateProxy +from reflex.utils import console, types +from reflex_base.event.context import EventContext +from reflex_base.event.processor.event_processor import EventProcessor, EventQueueEntry +from reflex_base.registry import RegisteredEventHandler +from reflex_base.utils.format import format_event_handler + +if TYPE_CHECKING: + from reflex.event import EventHandler, 
EventSpec + from reflex.state import BaseState + + +@functools.lru_cache(maxsize=1) +def _hydrate_event_name(): + from reflex.state import State + + return format_event_handler(State.event_handlers["hydrate"]) + + +def _check_valid_yield(events: Any, handler_name: str = "unknown") -> Any: + """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. + + Args: + events: The events to be checked. + handler_name: The name of the handler that yielded the events, used for error messages. + + Returns: + The events as they are if valid. + + Raises: + TypeError: If any of the events are not valid. + """ + from reflex.event import Event, EventHandler, EventSpec + + def _is_valid_type(events: Any) -> bool: + return isinstance(events, (Event, EventHandler, EventSpec)) + + if events is None or _is_valid_type(events): + return events + + if not (isinstance(events, Sequence) and not isinstance(events, (str, bytes))): + events = [events] + + try: + if all(_is_valid_type(e) for e in events): + return events + except TypeError: + pass + + coroutines = [e for e in events if inspect.iscoroutine(e)] + + for coroutine in coroutines: + coroutine_name = coroutine.__qualname__ + warnings.filterwarnings( + "ignore", message=f"coroutine '{coroutine_name}' was never awaited" + ) + + msg = ( + f"Your handler {handler_name} must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references)." + f" Returned events of types {', '.join(map(str, map(type, events)))!s}." + ) + raise TypeError(msg) + + +def _transform_event_arg(value: Any, hinted_args: Any) -> Any: + """Transform an event argument based on its type hint. + + Args: + value: The value to transform. + hinted_args: The type hint for the argument. + + Returns: + The transformed value. + + Raises: + ValueError: If a string value is received for an int or float type and cannot be converted. 
+ """ + from reflex.model import Model + from reflex.utils.serializers import deserializers + + if hinted_args is Any: + return value + if types.is_union(hinted_args): + if value is None: + return value + hinted_args = types.value_inside_optional(hinted_args) + if ( + isinstance(value, dict) + and isinstance(hinted_args, type) + and not types.is_generic_alias(hinted_args) # py3.10 + ): + if issubclass(hinted_args, Model): + # Remove non-fields from the payload + return hinted_args(**{ + key: value + for key, value in value.items() + if key in hinted_args.__fields__ + }) + if dataclasses.is_dataclass(hinted_args): + return hinted_args(**value) + if find_spec("pydantic"): + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + if issubclass(hinted_args, BaseModelV1): + return hinted_args.parse_obj(value) + if issubclass(hinted_args, BaseModelV2): + return hinted_args.model_validate(value) + if isinstance(value, list) and (hinted_args is set or hinted_args is frozenset): + return set(value) + if isinstance(value, list) and hinted_args is tuple: + return tuple(value) + if isinstance(hinted_args, type) and issubclass(hinted_args, Enum): + try: + return hinted_args(value) + except ValueError: + msg = f"Received an invalid enum value ({value}) for type {hinted_args}" + raise ValueError(msg) from None + if ( + isinstance(value, str) + and (deserializer := deserializers.get(hinted_args)) is not None + ): + try: + return deserializer(value) + except ValueError: + msg = f"Received a string value ({value}) but expected a {hinted_args}" + raise ValueError(msg) from None + return value + + +def _transform_event_payload( + payload: Mapping[str, Any], type_hints: Mapping[str, Any] +) -> dict[str, Any]: + """Transform an event payload based on the type hints of the handler. + + Args: + payload: The event payload to transform. + type_hints: The type hints for the handler's arguments. + + Returns: + The transformed event payload. 
+ """ + transformed = {} + for arg, value in list(payload.items()): + hinted_args = type_hints.get(arg, Any) + try: + transformed[arg] = _transform_event_arg(value, hinted_args) + except Exception as ex: + msg = f"Error transforming event argument '{arg}' with value '{value}' and type hint '{hinted_args}'" + raise ValueError(msg) from ex + return transformed + + +async def chain_updates( + events: EventSpec | list[EventSpec] | None, + handler_name: str, + root_state: BaseState | None = None, +) -> None: + """Chain yielded events and emit a delta to the frontend. + + Check for validitity and convert the EventSpec into qualified Event objects + to be queued against the current EventContext. + + Args: + events: The events to queue with the update. + handler_name: The name of the handler that yielded the events, used for error messages. + root_state: The root state of the app, no delta emitted if omitted. + """ + from reflex.event import Event + + ctx = EventContext.get() + + if root_state is not None: + # Emit deltas first, so any frontend events are processed with the latest state. + try: + delta = await root_state._get_resolved_delta() + if delta: + await ctx.emit_delta(delta) + finally: + root_state._clean() + + # Convert valid EventHandler and EventSpec into Event + if fixed_events := Event.from_event_type( + _check_valid_yield(events, handler_name=handler_name), + ): + # Frontend events. + if frontend_events := [e for e in fixed_events if e.name.startswith("_")]: + await ctx.emit_event(*frontend_events) + # Backend events. + await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) + + +async def process_event( + handler: EventHandler, + payload: dict, + state: BaseState | StateProxy, + root_state: BaseState, +): + """Process event. + + Args: + handler: EventHandler to process. + payload: The event payload. + state: State to process the handler. + root_state: The root state of the app, used for emitting deltas. 
+ + Raises: + ValueError: If a string value is received for an int or float type and cannot be converted. + """ + handler_name = handler.fn.__qualname__ + + # Get the function to process the event. + fn = functools.partial(handler.fn, state) + + try: + type_hints = types.get_type_hints(handler.fn) + payload = _transform_event_payload(payload, type_hints) + except Exception as ex: + # No transformation was possible, continue with the original payload + console.warn( + f"Error transforming event payload for handler {handler_name}: {ex}" + ) + + # Handle async functions. + if inspect.iscoroutinefunction(fn.func): + events = await fn(**payload) + + # Handle regular functions. + else: + events = fn(**payload) + # Handle async generators. + if inspect.isasyncgen(events): + async for event in events: + await chain_updates(event, root_state=root_state, handler_name=handler_name) + await chain_updates(None, root_state=root_state, handler_name=handler_name) + + # Handle regular generators. + elif inspect.isgenerator(events): + try: + while True: + await chain_updates( + next(events), root_state=root_state, handler_name=handler_name + ) + except StopIteration as si: + # the "return" value of the generator is not available + # in the loop, we must catch StopIteration to access it + if si.value is not None: + await chain_updates( + si.value, root_state=root_state, handler_name=handler_name + ) + await chain_updates(None, root_state=root_state, handler_name=handler_name) + + # Handle regular event chains. + else: + await chain_updates(events, root_state=root_state, handler_name=handler_name) + + +class BaseStateEventProcessor(EventProcessor): + """Event processor for BaseState-derived states. + + This processor is used to process events for BaseState-derived states, and + is responsible for maintaining the event queue and emitting deltas to the + frontend. + """ + + async def _rehydrate(self, root_state: BaseState): + """Rehydrate the state by calling the hydrate event handler. 
+ + Args: + root_state: The root state to rehydrate. + """ + from reflex.state import OnLoadInternalState, State + + if ( + type(root_state) is not State + or OnLoadInternalState.get_name() not in root_state.substates + ): + return + + await process_event( + handler=State.event_handlers["hydrate"], + payload={}, + state=root_state, + root_state=root_state, + ) + await process_event( + handler=OnLoadInternalState.event_handlers["on_load_internal"], + payload={}, + state=await root_state.get_state(OnLoadInternalState), + root_state=root_state, + ) + + async def _execute_event( + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Execute the handler for a single event queue entry with full state management. + + The ``EventContext`` has already been set by ``_process_event_queue_entry`` + before this method is called. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + ctx = entry.ctx + event = entry.event + router_data = event.router_data or {} + # Get the state for the session exclusively. + async with ctx.state_manager.modify_state_with_links( + BaseStateToken( + ident=ctx.token, + cls=registered_handler.states[0], + ), + event=entry.event, + ) as state: + # Compatibility hack rehydrate the state before processing this event. + needs_to_rehydrate = bool( + not state.router_data and event.name != _hydrate_event_name() + ) + + # re-assign only when the value is set and different + if router_data and state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data + if state.router != (router := RouterData.from_router_data(router_data)): + state.router = router + + # Preprocess the event. 
+ if ( + self.middleware is not None + and (update := await self.middleware._preprocess(state, event)) + is not None + ): + # If there was an update, yield it. + if update.delta: + await ctx.emit_delta(update.delta) + if update.events: + await ctx.enqueue(*update.events) + return + + # Get the event's substate. + substate = await state.get_state(event.state_cls) + root_state = state._get_root_state() + + if needs_to_rehydrate: + await self._rehydrate(root_state) + + # Process non-background events while holding the lock. + if not registered_handler.handler.is_background: + await process_event( + handler=registered_handler.handler, + payload=event.payload, + state=substate, + root_state=root_state, + ) + return + # Otherwise drop the state lock and start processing the background task with a proxy state. + await process_event( + handler=registered_handler.handler, + state=StateProxy(substate), + payload=event.payload, + root_state=root_state, + ) + + async def _handle_backend_exception( + self, ex: Exception, ev_ctx: EventContext | None = None + ) -> None: + """Handle an exception raised during event processing by calling the backend exception handler if it exists. + + Args: + ex: The exception that was raised. + ev_ctx: The event context for the exception. + """ + if self.backend_exception_handler is not None: + if ev_ctx is not None: + # Ensure the event context is set for the exception handler. 
+ EventContext.set(ev_ctx) + if events := self.backend_exception_handler(ex): + await chain_updates( + events=events, + handler_name=self.backend_exception_handler.__qualname__, + ) + + +__all__ = ["BaseStateEventProcessor", "chain_updates", "process_event"] diff --git a/packages/reflex-base/src/reflex_base/event/processor/compat.py b/packages/reflex-base/src/reflex_base/event/processor/compat.py new file mode 100644 index 00000000000..040e1aff475 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/compat.py @@ -0,0 +1,87 @@ +"""Compatibility shims since asyncio changes quite a bit from 3.11 to 3.14.""" + +import asyncio +import sys + +if sys.version_info >= (3, 13): + from asyncio import as_completed as as_completed +else: + # The following implementation of as_completed is adapted from Python 3.14 + # python/cpython@9e1f1644cd7b7661f0748bb37351836e8d6f37e2 + + class _AsCompletedIterator: + """Iterator of awaitables representing tasks of asyncio.as_completed. + + As an asynchronous iterator, iteration yields futures as they finish. As a + plain iterator, new coroutines are yielded that will return or raise the + result of the next underlying future to complete. 
+ """ + + def __init__(self, aws, timeout): # noqa: ANN001 + self._done = asyncio.Queue() + self._timeout_handle = None + + loop = asyncio.get_event_loop() + todo = {asyncio.ensure_future(aw, loop=loop) for aw in set(aws)} + for f in todo: + f.add_done_callback(self._handle_completion) + if todo and timeout is not None: + self._timeout_handle = loop.call_later(timeout, self._handle_timeout) + self._todo = todo + self._todo_left = len(todo) + + def __aiter__(self): + return self + + def __iter__(self): + return self + + async def __anext__(self): + if not self._todo_left: + raise StopAsyncIteration + assert self._todo_left > 0 + self._todo_left -= 1 + return await self._wait_for_one() + + def __next__(self): + if not self._todo_left: + raise StopIteration + assert self._todo_left > 0 + self._todo_left -= 1 + return self._wait_for_one(resolve=True) + + def _handle_timeout(self): + for f in self._todo: + f.remove_done_callback(self._handle_completion) + self._done.put_nowait(None) # Sentinel for _wait_for_one(). + self._todo.clear() # Can't do todo.remove(f) in the loop. + + def _handle_completion(self, f): # noqa: ANN001 + if not self._todo: + return # _handle_timeout() was here first. + self._todo.remove(f) + self._done.put_nowait(f) + if not self._todo and self._timeout_handle is not None: + self._timeout_handle.cancel() + + async def _wait_for_one(self, resolve=False): # noqa: ANN001 + # Wait for the next future to be done and return it unless resolve is + # set, in which case return either the result of the future or raise + # an exception. + f = await self._done.get() + if f is None: + # Dummy value from _handle_timeout(). + raise asyncio.TimeoutError + return f.result() if resolve else f + + def as_completed(aws, *, timeout=None): # noqa: ANN001 + """Return an iterator of coroutines that yield the results of the given awaitables. + + The coroutines are ordered in the order in which the given awaitables complete. 
+ If a given awaitable raises an exception, the corresponding coroutine raises the same exception. + + Args: + aws: An iterable of awaitables. + timeout: If provided, the maximum number of seconds to wait for the next awaitable to complete before raising a TimeoutError. + """ + return _AsCompletedIterator(aws, timeout) diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py new file mode 100644 index 00000000000..f22c4fb78ba --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -0,0 +1,734 @@ +"""Base EventProcessor class for handling backend event queue.""" + +from __future__ import annotations + +import asyncio +import collections +import contextlib +import dataclasses +import inspect +import sys +import time +import traceback +from collections.abc import AsyncGenerator, Callable, Mapping, Sequence +from contextvars import Token, copy_context +from typing import TYPE_CHECKING, Any + +import rich.markup +from typing_extensions import Self + +from reflex.app_mixins.middleware import MiddlewareMixin +from reflex.istate.manager import StateManager +from reflex.utils import console +from reflex_base.event.context import EventContext +from reflex_base.event.processor.compat import as_completed +from reflex_base.event.processor.future import EventFuture +from reflex_base.event.processor.timeout import DrainTimeoutManager +from reflex_base.registry import RegisteredEventHandler, RegistrationContext + +if TYPE_CHECKING: + from reflex.app import EventNamespace + from reflex.event import Event, EventSpec + +if hasattr(asyncio, "QueueShutDown"): + + class QueueShutDown(asyncio.QueueShutDown): # pyright: ignore[reportRedeclaration] + """Exception raised when trying to put an item into a shut down queue.""" + +else: + + class QueueShutDown(Exception): # noqa: N818 + """Exception raised when trying to put an item into a shut down 
queue.""" + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class EventQueueEntry: + """An entry in the event queue.""" + + event: Event + ctx: EventContext + + +@dataclasses.dataclass(kw_only=True, slots=True) +class EventProcessor: + """Responsible for queuing and processing events. + + Attributes: + middleware: An optional middleware mixin to apply to all events processed by this processor. + backend_exception_handler: An optional function to handle exceptions raised during event processing. The function should take an Exception as input and return an EventSpec or list of EventSpecs to be emitted in response, or None to not emit any events. + graceful_shutdown_timeout: An optional amount of time in seconds to wait for the queue to drain before forcefully cancelling tasks when stopping the processor. If None, the processor will not wait and will cancel tasks immediately. + + _queue: The asyncio queue for events to be processed. + _queue_task: The task responsible for processing the event queue. + _root_context: The root event context to use for events enqueued without an explicit context. + _attached_root_context_token: The context variable token for the attached root context, used to reset the context variable on shutdown. + _tasks: A mapping of active transaction ids to their corresponding event handler tasks, used for tracking and cancellation on shutdown. 
+ """ + + middleware: MiddlewareMixin | None = None + backend_exception_handler: ( + Callable[[Exception], EventSpec | list[EventSpec] | None] | None + ) = None + graceful_shutdown_timeout: float | None = None + + _queue: asyncio.Queue[EventQueueEntry] | None = dataclasses.field( + default=None, init=False + ) + _queue_task: asyncio.Task | None = dataclasses.field(default=None, init=False) + _root_context: EventContext | None = dataclasses.field(default=None, init=False) + _attached_root_context_token: Token | None = dataclasses.field( + default=None, init=False + ) + _tasks: dict[str, asyncio.Task] = dataclasses.field( + default_factory=dict, init=False + ) + _futures: dict[str, EventFuture] = dataclasses.field( + default_factory=dict, init=False + ) + _token_queues: dict[ + str, + collections.deque[tuple[EventQueueEntry, RegisteredEventHandler]], + ] = dataclasses.field(default_factory=dict, init=False) + + def configure( + self, + *, + state_manager: StateManager | None = None, + event_namespace: EventNamespace | None = None, + ) -> Self: + """Set up the event processor. + + Before an event processor can be used, it must be configured with a + state manager and optionally an event namespace to communicate with the + frontend. + + Args: + state_manager: The state manager to use for processing events. + event_namespace: The event namespace to use for processing events. + + Returns: + The event processor instance. + """ + from reflex.istate.manager.memory import StateManagerMemory + from reflex.state import StateUpdate + + if self._root_context is not None: + msg = ( + "Event processor is already configured, call .configure(...) only once." + ) + raise RuntimeError(msg) + + emit_delta_impl = emit_event_impl = None + if event_namespace is not None: + + async def emit_delta( + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + """Emit a delta to the frontend. + + Args: + token: The client token to emit the delta to. 
+ delta: The delta to emit, mapping client tokens to variable updates. + """ + await event_namespace.emit_update( + update=StateUpdate(delta=delta), + token=token, + ) + + emit_delta_impl = emit_delta + + async def emit_event(token: str, *events: Event) -> None: + """Emit an event to be processed on the frontend. + + If no such handler exists, the event will not be processed. + + Args: + token: The client token to emit the event to. + events: The events to emit. + """ + await event_namespace.emit_update( + update=StateUpdate(events=list(events)), + token=token, + ) + + emit_event_impl = emit_event + + if state_manager is None: + # For testing use cases, default to a new in-memory state manager if one is not provided. + state_manager = StateManagerMemory() + + self._root_context = EventContext( + token="", + parent_txid=None, + state_manager=state_manager, + enqueue_impl=self.enqueue_many, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + return self + + async def __aenter__(self) -> Self: + """Enter the event processor context manager. + + Returns: + The event processor instance. + """ + await self.start() + return self + + async def __aexit__(self, *exc_info) -> None: + """Exit the event processor context manager and stop the processor.""" + await self.stop() + + async def start(self) -> None: + """Start the event processor.""" + if self._root_context is None: + msg = "Event processor is not configured, call .configure(...) first." + raise RuntimeError(msg) + if self._queue is not None: + msg = "Event processor is already started" + raise RuntimeError(msg) + if self._attached_root_context_token is not None: + msg = "EventProcessor context cannot be nested." 
+ raise RuntimeError(msg) + self._attached_root_context_token = EventContext.set(self._root_context) + self._queue = asyncio.Queue() + self._ensure_queue_task() + + async def _stop_tasks(self, timeout: float | None = None) -> None: + """Stop all running tasks with an optional drain time. + + Args: + timeout: An optional amount of time in seconds to wait for the + queue to drain before cancelling tasks. If None, the processor will + not wait and will cancel tasks immediately. + """ + finished_tasks = set() + # Graceful drain time, wait for tasks to finish and handle any exceptions. + if timeout is not None and self._tasks: + with contextlib.suppress(asyncio.TimeoutError): + async for task in as_completed(self._tasks.values(), timeout=timeout): + # Exceptions are handled in _finish_task and ignored here. + with contextlib.suppress(Exception): + await task + finished_tasks.add(task) + # Cancel all outstanding event handler tasks. + outstanding_tasks = [ + task for task in self._tasks.values() if task not in finished_tasks + ] + for task in outstanding_tasks: + task.cancel() + # Wait for all tasks to finish and log any exceptions that were raised. + for task in outstanding_tasks: + with contextlib.suppress(Exception, asyncio.CancelledError): + # Exceptions are handled in _finish_task. + await task + + async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: + """Stop the event processor and cancel all running tasks. + + Args: + graceful_shutdown_timeout: An optional amount of time in seconds to wait for the + queue to drain before cancelling tasks. If None, the processor will + not wait and will cancel tasks immediately. + """ + from reflex.utils import telemetry + + if self._attached_root_context_token is not None: + EventContext.reset(self._attached_root_context_token) + self._attached_root_context_token = None + # Optional grace period for tasks to finish before cancellation. 
+ if graceful_shutdown_timeout is None: + graceful_shutdown_timeout = self.graceful_shutdown_timeout + drain_timeout = DrainTimeoutManager.with_timeout(graceful_shutdown_timeout) + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + if remaining_time > 0: + # Drain the queue first of any pending events. + await self.join(timeout=remaining_time) + # Stopping tasks may raise exceptions and chain additional deltas so the queue remains open. + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + await self._stop_tasks(timeout=remaining_time) + # Cancel queue processing now that all tasks have been cancelled. + queue = self._queue + if self._queue is not None: + if sys.version_info >= (3, 13): + self._queue.shutdown() + self._queue = None + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + if remaining_time > 0: + await self.join(timeout=remaining_time, queue=queue) + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + # Stop all tasks again now that the queue is shut down, no additional events can be queued. + await self._stop_tasks(timeout=remaining_time) + if self._queue_task is not None: + self._queue_task.cancel() + try: + await self._queue_task + except (asyncio.CancelledError, QueueShutDown, RuntimeError): + pass + except Exception as ex: + telemetry.send_error(ex, context="backend") + console.error( + rich.markup.escape( + f"Error in event processor queue task during shutdown:\n{traceback.format_exc()}" + ) + ) + self._queue_task = None + # Discard any pending per-token queue entries. + self._token_queues.clear() + # Cancel any remaining unresolved futures. 
+ for future in self._futures.values(): + if not future.done(): + future.cancel() + self._futures.clear() + + async def join( + self, timeout: float | None = None, queue: asyncio.Queue | None = None + ) -> None: + """Wait for the event processor to finish processing all events in the queue. + + Args: + timeout: An optional amount of time in seconds to wait for the queue to + drain before returning. If None, this method will wait indefinitely + until the queue is fully drained. + queue: An optional queue to wait for instead of the processor's main + queue. This can be used to wait for a specific queue to drain, such + as when using a separate queue for testing. + """ + if queue is None: + queue = self._queue + if queue is not None: + await asyncio.wait_for(queue.join(), timeout=timeout) + + def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: + """Ensure the queue processing task is running. + + Returns: + The event queue. + + Raises: + RuntimeError: If the event processor is not running and no queue is provided. + """ + if self._root_context is None: + msg = "Event processor is not configured, call .configure(...) first." + raise RuntimeError(msg) + if self._queue is None: + msg = "Event processor is not running, call .start(...) first." + raise QueueShutDown(msg) + if self._queue_task is None: + task_context = copy_context() + task_context.run(EventContext.set, self._root_context) + self._queue_task = task_context.run( + asyncio.create_task, + self._process_queue(), + name=f"reflex_event_queue_processor|{time.time()}", + ) + return self._queue + + async def enqueue( + self, token: str, event: Event, ev_ctx: EventContext | None = None + ) -> EventFuture: + """Enqueue an event to be processed. + + Args: + token: The client token associated with the event. + event: The event to be enqueued. + ev_ctx: The event context to use for this event. + + Returns: + An EventFuture that resolves to the result of the associated task. 
+ """ + if ev_ctx is None: + try: + ev_ctx = EventContext.get().fork(token=token) + except LookupError as le: + if self._root_context is not None: + ev_ctx = self._root_context.fork(token=token) + else: + msg = "Event processor is not running, call .start(...) first." + raise RuntimeError(msg) from le + queue = self._ensure_queue_task() + txid = ev_ctx.txid + parent_future = ( + self._futures.get(ev_ctx.parent_txid) + if ev_ctx.parent_txid is not None + else None + ) + tracked = EventFuture(parent=parent_future, txid=txid) + self._futures[txid] = tracked + tracked.add_done_callback(self._try_clean_future) + tracked.add_done_callback(self._on_future_done) + # If this context has a parent, register as a child of the parent's future. + if parent_future is not None: + parent_future.add_child(tracked) + await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) + return tracked + + async def enqueue_many(self, token: str, *events: Event) -> Sequence[EventFuture]: + """Enqueue multiple events to be processed. + + Args: + token: The client token associated with the events. + events: Remaining positional args are events to be enqueued. + + Returns: + A list of EventFutures corresponding to each enqueued event. + """ + return [await self.enqueue(token, event) for event in events] + + async def enqueue_stream_delta( + self, + token: str, + event: Event, + ) -> AsyncGenerator[Mapping[str, Any]]: + """Enqueue an event to be processed and yield deltas emitted by the event handler. + + Events queued by this method will not emit deltas to their target token in the typical way, instead + they will be yielded from this generator until the event handler finishes processing. + Deltas emitted for other tokens will be handled normally. + + Any frontend events or chained events are handled normally and deltas from chained events + will not be yielded by this method. + + Args: + token: The client token associated with the event. + event: The event to be enqueued. 
+ + Yields: + Deltas emitted by the event handler for the specified token. + """ + if self._root_context is None: + msg = "Event processor is not configured, call .configure(...) first." + raise RuntimeError(msg) + + deltas = asyncio.Queue() + + async def _emit_delta_impl( + delta_token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + if ( + delta_token != token + and self._root_context is not None + and self._root_context.emit_delta_impl is not None + ): + # Emit deltas for other tokens normally. + await self._root_context.emit_delta_impl(delta_token, delta) + return + await deltas.put(delta) + + task_future = await self.enqueue( + token, + event, + ev_ctx=dataclasses.replace( + self._root_context, + token=token, + emit_delta_impl=_emit_delta_impl, + ), + ) + all_task_futures = asyncio.create_task(task_future.wait_all()) + waiting_for = {all_task_futures, asyncio.create_task(deltas.get())} + try: + while not all_task_futures.done() or not deltas.empty(): + with contextlib.suppress(asyncio.TimeoutError): + async for result in as_completed( + waiting_for, + timeout=1, + ): + waiting_for.remove(result) + if result is not all_task_futures: + yield await result + waiting_for.add(asyncio.create_task(deltas.get())) + break + finally: + for future in waiting_for: + future.cancel() + # Raise any exceptions for the caller, waiting for all chained events. + await task_future.wait_all() + + def _try_clean_future(self, future: EventFuture) -> None: # type: ignore[override] + """Pop a future from _futures when it and all immediate children are done. + + After popping, cascade the check upward: if the parent future is also + done and all its immediate children are done, pop the parent as well. + + This keeps parent futures alive in ``_futures`` while any child still + needs them for ``wait_all`` and cleanup. + + Args: + future: The EventFuture to check. + """ + if not future.done(): + return + # Not checking future.all_done() to avoid waiting for grandchildren here. 
+ if not all(c.done() for c in future.children): + return + parent = future.parent + self._futures.pop(future.txid, None) + if parent is not None and parent.txid: + self._try_clean_future(parent) + + def _on_future_done(self, future: EventFuture) -> None: # type: ignore[override] + """Callback invoked when an enqueued future completes. + + If the future was cancelled externally, cancel the running task + and all child futures. If the task has not started yet, + ``_process_queue`` will check the future and skip it when the + entry is dequeued. + + Args: + future: The EventFuture that completed. + """ + if not future.cancelled(): + return + # Cascade cancellation to all child futures. + for child in future.children: + child.cancel() + task = self._tasks.get(future.txid) + if task is not None: + task.cancel() + + async def _execute_event( + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Execute the handler for a single event queue entry. + + This method contains the actual event-processing logic. The base + implementation simply invokes the registered handler function with the + event payload. Subclasses (e.g. ``BaseStateEventProcessor``) override + this method to add state management, delta emission, and middleware. + + ``_process_event_queue_entry`` is responsible for setting up the + ``EventContext`` and ensuring sequential ordering *before* calling this + method. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + event = entry.event + result = registered_handler.handler.fn(**event.payload) + if inspect.isawaitable(result): + await result + + async def _process_event_queue_entry( + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Process a single event queue entry. + + This function runs in a new task for each event. 
It sets up the + ``EventContext``, enforces sequential ordering for non-background + events, and then delegates to ``_execute_event`` for the actual + handler invocation. + + Subclasses should override ``_execute_event`` rather than this method + so that the shared context setup and sequential-ordering logic is + always applied. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + # Set up the event context for this task. + EventContext.set(entry.ctx) + await self._execute_event(entry=entry, registered_handler=registered_handler) + + def _create_event_task( + self, + *, + entry: EventQueueEntry, + registered_handler: RegisteredEventHandler, + ) -> asyncio.Task: + """Create and register an asyncio task for processing a single event. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + + Returns: + The created asyncio.Task. + """ + task = asyncio.create_task( + self._process_event_queue_entry( + entry=entry, registered_handler=registered_handler + ), + name=f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}", + ) + if sys.version_info < (3, 12): + task._event_ctx = entry.ctx # pyright: ignore[reportAttributeAccessIssue] + self._tasks[entry.ctx.txid] = task + task.add_done_callback(self._finish_task) + return task + + def _enqueue_for_token( + self, + *, + entry: EventQueueEntry, + registered_handler: RegisteredEventHandler, + ) -> None: + """Append an event to the per-token queue and dispatch if idle. + + If no queue exists for the token yet, one is created. If this is + the first (and therefore only) entry, a task is dispatched + immediately. + + Args: + entry: The event queue entry to enqueue. + registered_handler: The registered handler for the event. 
+ """ + token = entry.ctx.token + token_queue = self._token_queues.get(token) + if token_queue is None: + token_queue = self._token_queues[token] = collections.deque() + token_queue.append((entry, registered_handler)) + if len(token_queue) == 1: + self._dispatch_next_for_token(token) + + def _dispatch_next_for_token(self, token: str) -> None: + """Create a task for the front entry in the per-token queue. + + Args: + token: The client token whose queue to dispatch from. + """ + token_queue = self._token_queues.get(token) + if not token_queue: + return + entry, registered_handler = token_queue[0] + # Skip cancelled futures. + future = self._futures.get(entry.ctx.txid) + if future is not None and future.cancelled(): + self._try_clean_future(future) + token_queue.popleft() + if token_queue: + self._dispatch_next_for_token(token) + else: + del self._token_queues[token] + return + self._create_event_task(entry=entry, registered_handler=registered_handler) + + async def _process_queue(self): + """Process events from the queue in a task.""" + if (queue := self._queue) is None: + msg = "Event processor is not running, call .start(...) first." + raise RuntimeError(msg) + with contextlib.suppress(QueueShutDown): + while True: + entry = await queue.get() + if ( + future := self._futures.get(entry.ctx.txid) + ) is not None and future.cancelled(): + self._try_clean_future(future) + queue.task_done() + continue + try: + try: + registered_handler = RegistrationContext.get().event_handlers[ + entry.event.name + ] + except KeyError as ke: + msg = ( + f"No registered handler found for event: {entry.event.name}" + ) + raise KeyError(msg) from ke + if registered_handler.handler.is_background: + # Background events run immediately, bypassing per-token ordering. + self._create_event_task( + entry=entry, registered_handler=registered_handler + ) + else: + # Sequential events go through the per-token queue. 
+ self._enqueue_for_token( + entry=entry, registered_handler=registered_handler + ) + except Exception: + # Log the error and continue processing the next events. + console.error( + rich.markup.escape( + f"Error processing event queue entry for {entry.event} [txid={entry.ctx.txid}]:\n{traceback.format_exc()}" + ) + ) + queue.task_done() + if self._queue_task is asyncio.current_task(): + self._queue_task = None + + async def _handle_backend_exception( + self, ex: Exception, ev_ctx: EventContext | None = None + ) -> None: + """Handle an exception raised during event processing by calling the backend exception handler if it exists. + + Args: + ex: The exception that was raised. + ev_ctx: The event context for the exception, if available. This will be set in the context variable when calling the exception handler. + """ + if self.backend_exception_handler is not None: + if ev_ctx is not None: + EventContext.set(ev_ctx) + self.backend_exception_handler(ex) + + def _finish_task(self, task: asyncio.Task): + """Callback for finishing a _process_event_queue_entry task. + + This function is responsible for calling the backend exception handler + if the task raised an exception, and logging any errors that occur + during the process. + + Args: + task: The task that finished. + """ + from reflex.utils import telemetry + + if sys.version_info < (3, 12): + # py3.11 compat + task_ctx = task._event_ctx # type: ignore[attr-defined] + else: + task_ctx = task.get_context().run(EventContext.get) + self._tasks.pop(task_ctx.txid, None) + # Chain the next sequential event for this token if applicable. 
+ token_queue = self._token_queues.get(task_ctx.token) + if token_queue and token_queue[0][0].ctx.txid == task_ctx.txid: + token_queue.popleft() + if token_queue: + self._dispatch_next_for_token(task_ctx.token) + else: + del self._token_queues[task_ctx.token] + future = self._futures.get(task_ctx.txid) + if task.done(): + try: + result = task.result() + except asyncio.CancelledError: + if future is not None and not future.done(): + future.cancel() + except Exception as ex: + if future is not None and not future.done(): + future.set_exception(ex) + with contextlib.suppress(BaseException): + # Trigger the future to avoid warnings if the caller didn't wait. + future.result() + telemetry.send_error(ex, context="backend") + if ( + not task.get_name().startswith("reflex_backend_exception_handler|") + and self.backend_exception_handler is not None + ): + # Create a new task in the same context to invoke the exception handler. + t = self._tasks[task_ctx.txid] = asyncio.create_task( + self._handle_backend_exception(ex, ev_ctx=task_ctx), + name=f"reflex_backend_exception_handler|task=[{task.get_name()}]|{time.time()}", + ) + if sys.version_info < (3, 12): + t._event_ctx = task_ctx # pyright: ignore[reportAttributeAccessIssue] + t.add_done_callback(self._finish_task) + return + console.error( + rich.markup.escape( + f"Error in {task.get_name()} [txid={task_ctx.txid}]:\n{traceback.format_exc()}" + ) + ) + else: + if future is not None and not future.done(): + future.set_result(result) + + +__all__ = [ + "EventFuture", + "EventProcessor", + "EventQueueEntry", +] diff --git a/packages/reflex-base/src/reflex_base/event/processor/future.py b/packages/reflex-base/src/reflex_base/event/processor/future.py new file mode 100644 index 00000000000..01d27fbdcef --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/future.py @@ -0,0 +1,100 @@ +"""EventFuture: a future that tracks child futures for hierarchical event processing.""" + +from __future__ import annotations + 
+import asyncio +import contextlib +import dataclasses +from typing import Any + + +@dataclasses.dataclass(kw_only=True, slots=True, eq=False) +class EventFuture(asyncio.Future): + """A future that tracks child futures for hierarchical event processing. + + When events are chained (a handler enqueues additional events), the child + futures are tracked so callers can wait for the entire chain to complete. + """ + + # The transaction id associated with this future. + txid: str + + # Child futures spawned by this future, if any. + children: list[EventFuture] = dataclasses.field(default_factory=list) + + # The parent future that spawned this one, or None if this future was + # enqueued directly from the queue rather than chained from another event. + parent: EventFuture | None = dataclasses.field(default=None, repr=False) + + # The event loop that this future is running on. + loop: asyncio.AbstractEventLoop = dataclasses.field( + default_factory=asyncio.get_running_loop, repr=False + ) + + def __post_init__(self) -> None: + """Call Future.__init__ for the EventFuture.""" + super(EventFuture, self).__init__(loop=self.loop) + + def add_child(self, child: EventFuture) -> None: + """Add a child future to this tracked future. + + Args: + child: The child EventFuture to add. + + Raises: + RuntimeError: If this future is already done. + """ + if self.done(): + msg = "Cannot add a child to an EventFuture that is already done." + raise RuntimeError(msg) + self.children.append(child) + + def all_done(self) -> bool: + """Check if this future and all descendant futures are done. + + Returns: + True if this future and all descendants have completed. + """ + if not self.done(): + return False + return all(child.all_done() for child in self.children) + + async def wait_all(self) -> Any: + """Wait for this future and all descendant futures to complete. + + Walks the children list by index so that children added after + iteration begins are still awaited. 
+ + Child exceptions are suppressed since they are handled independently + by the event processor's _finish_task callback. + + Returns: + The result of this future. + """ + result = await self + i = 0 + while i < len(self.children): + child = self.children[i] + with contextlib.suppress(Exception, asyncio.CancelledError): + await child.wait_all() + i += 1 + return result + + def cancel(self, msg: object = None) -> bool: + """Cancel this future and all descendant futures. + + Args: + msg: Optional cancellation message. + + Returns: + True if the future was successfully cancelled. + """ + result = super(EventFuture, self).cancel(msg) + for child in self.children: + child.cancel(msg) + return result + + +__all__ = [ + "EventFuture", +] diff --git a/packages/reflex-base/src/reflex_base/event/processor/timeout.py b/packages/reflex-base/src/reflex_base/event/processor/timeout.py new file mode 100644 index 00000000000..a527221ad95 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/timeout.py @@ -0,0 +1,52 @@ +"""DrainTimeoutManager: manages an optional combined timeout over multiple calls.""" + +from __future__ import annotations + +import dataclasses +import time + + +@dataclasses.dataclass(kw_only=True, slots=True) +class DrainTimeoutManager: + """Manages an optional combined timeout over multiple calls. + + Each time the context is entered, yield the remaining time until the + overall timeout is reached, or 0 if the timeout has already been reached. + This allows multiple operations to share a single overall timeout, even if + they are not executed sequentially. + """ + + drain_deadline: float | None = None + + @classmethod + def with_timeout(cls, timeout: float | None) -> DrainTimeoutManager: + """Create a DrainTimeoutManager with a specified timeout. + + Args: + timeout: The overall amount of time in seconds to wait. + + Returns: + A DrainTimeoutManager instance with the drain deadline set. 
+ """ + if timeout is None: + return cls(drain_deadline=None) + return cls(drain_deadline=time.time() + timeout) + + def __enter__(self) -> float: + """Enter the context and yield the remaining time. + + Returns: + The remaining time in seconds until the overall timeout is reached, or 0 if the timeout + has already been reached. + """ + if self.drain_deadline is not None: + return max(0, self.drain_deadline - time.time()) + return 0 + + def __exit__(self, *exc_info) -> None: + """Exit the context. No cleanup necessary.""" + + +__all__ = [ + "DrainTimeoutManager", +] diff --git a/packages/reflex-base/src/reflex_base/plugins/_screenshot.py b/packages/reflex-base/src/reflex_base/plugins/_screenshot.py index 8f1893e1d80..5d828193be2 100644 --- a/packages/reflex-base/src/reflex_base/plugins/_screenshot.py +++ b/packages/reflex-base/src/reflex_base/plugins/_screenshot.py @@ -97,7 +97,8 @@ async def clone_state(request: "Request") -> "Response": from starlette.responses import JSONResponse - from reflex.state import _substate_key + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State if not app.event_namespace: return JSONResponse({}) @@ -109,7 +110,9 @@ async def clone_state(request: "Request") -> "Response": {"error": "Token to clone must be a string."}, status_code=400 ) - old_state = await app.state_manager.get_state(token_to_clone) + old_state = await app.state_manager.get_state( + BaseStateToken(ident=token_to_clone, cls=State), + ) new_state = _deep_copy(old_state) @@ -132,7 +135,8 @@ async def clone_state(request: "Request") -> "Response": found_new = True await app.state_manager.set_state( - _substate_key(new_token, new_state), new_state + BaseStateToken(ident=new_token, cls=type(new_state)), + new_state, ) return JSONResponse(new_token) diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py new file mode 100644 index 00000000000..8caa1d2b2c3 --- /dev/null +++ 
b/packages/reflex-base/src/reflex_base/registry.py @@ -0,0 +1,160 @@ +"""A contextual registry for state and event handlers.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + +from typing_extensions import Self + +from reflex_base.context.base import BaseContext +from reflex_base.utils.exceptions import StateValueError + +if TYPE_CHECKING: + from reflex.state import BaseState + from reflex_base.components.component import StatefulComponent + from reflex_base.event import EventHandler + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class RegisteredEventHandler: + """A registered event handler, which includes the handler and its full name.""" + + handler: EventHandler + states: tuple[type[BaseState], ...] + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) +class RegistrationContext(BaseContext): + """Context for registering event handlers and states.""" + + event_handlers: dict[str, RegisteredEventHandler] = dataclasses.field( + default_factory=dict, + repr=False, + ) + base_states: dict[str, type[BaseState]] = dataclasses.field( + default_factory=dict, + repr=False, + ) + base_state_substates: dict[str, set[type[BaseState]]] = dataclasses.field( + default_factory=dict, + repr=False, + ) + tag_to_stateful_component: dict[str, StatefulComponent] = dataclasses.field( + default_factory=dict, + repr=False, + ) + + @classmethod + def ensure_context(cls) -> Self: + """Ensure the context is attached, or create a new instance and attach it. + + Returns: + The registration context instance. + """ + try: + return cls.get() + except LookupError: + # If the context is not attached, create a new instance and attach it. + ctx = cls() + cls._context_var.set(ctx) + return ctx + + @classmethod + def register_base_state(cls, state_cls: type[BaseState]) -> type[BaseState]: + """Register a base state class with its full name. 
+ + Also registers parent_state until finding one that is already registered. + + Args: + state_cls: The base state class to register. + + Returns: + The registered base state class. + """ + return cls.ensure_context()._register_base_state(state_cls) + + def _register_base_state(self, state_cls: type[BaseState]) -> type[BaseState]: + """Register a base state class with its full name. + + Also registers parent_state until finding one that is already registered. + + Args: + state_cls: The base state class to register. + + Returns: + The registered base state class. + """ + self.base_states[state_cls.get_full_name()] = state_cls + for event_handler in state_cls.event_handlers.values(): + self._register_event_handler(event_handler, states=(state_cls,)) + if (parent_state := state_cls.get_parent_state()) is not None: + if parent_state.get_full_name() not in self.base_states: + self._register_base_state(parent_state) + parent_state_substates = self.base_state_substates.setdefault( + parent_state.get_full_name(), set() + ) + if state_cls in parent_state_substates: + msg = ( + f"State class {state_cls.get_full_name()} is already registered as a substate of " + f"{parent_state.get_full_name()}. This likely means there are multiple classes with the same name " + "in the same module, which causes a conflict in the registry. Please rename one of the classes to avoid " + "shadowing. Shadowing substate classes is not allowed." + ) + raise StateValueError(msg) + parent_state_substates.add(state_cls) + return state_cls + + @classmethod + def register_event_handler( + cls, handler: EventHandler, states: tuple[type[BaseState], ...] = () + ) -> EventHandler: + """Register an event handler with its full name and associated states. + + Args: + handler: The event handler to register. + states: The states associated with the event handler. + + Returns: + The registered event handler. 
+ """ + return cls.ensure_context()._register_event_handler(handler, states=states) + + def _register_event_handler( + self, + handler: EventHandler, + states: tuple[type[BaseState], ...] = (), + ) -> EventHandler: + """Register an event handler with its full name and associated states. + + Args: + handler: The event handler to register. + states: The states associated with the event handler. + + Returns: + The registered event handler. + """ + from reflex.utils.format import format_event_handler + + full_name = format_event_handler(handler) + self.event_handlers[full_name] = RegisteredEventHandler( + handler=handler, states=states + ) + return handler + + def get_substates( + self, base_state_cls: type[BaseState] | str + ) -> set[type[BaseState]]: + """Get the substates for a base state class. + + Args: + base_state_cls: The base state class to get substates for. + + Returns: + A set of substate classes. + """ + if isinstance(base_state_cls, str): + return self.base_state_substates.setdefault(base_state_cls, set()) + return self.base_state_substates.setdefault( + base_state_cls.get_full_name(), set() + ) diff --git a/packages/reflex-base/src/reflex_base/utils/format.py b/packages/reflex-base/src/reflex_base/utils/format.py index bbfa9208fd5..c8f11e01f02 100644 --- a/packages/reflex-base/src/reflex_base/utils/format.py +++ b/packages/reflex-base/src/reflex_base/utils/format.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any from reflex_base import constants -from reflex_base.constants.state import FRONTEND_EVENT_STATE from reflex_base.utils import exceptions if TYPE_CHECKING: @@ -454,25 +453,20 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: Returns: The state and function name. """ - # Get the class that defines the event handler. - parts = handler.fn.__qualname__.split(".") + # Get the name of the event function. 
+ name = handler.fn.__qualname__ # Get the state full name - state_full_name = handler.state_full_name + state_full_name = handler.state.get_full_name() if handler.state else "" - # If there's no enclosing class, just return the function name. - if not state_full_name: - return ("", parts[-1]) + # If there's no enclosing state, just return the full name. + if handler.state is None: + return ("", name) - # Get the function name - name = parts[-1] + # Get the event name inside the state. + func_name = name.rpartition(".")[2] - from reflex.state import State - - if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__: - return ("", to_snake_case(handler.fn.__qualname__)) - - return (state_full_name, name) + return (state_full_name, func_name) def format_event_handler(handler: EventHandler) -> str: @@ -606,7 +600,7 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]: Returns: The reformatted query params """ - params = router_data[constants.RouteVar.QUERY] + params = router_data.get(constants.RouteVar.QUERY, {}) return {k.replace("-", "_"): v for k, v in params.items()} diff --git a/packages/reflex-base/src/reflex_base/utils/serializers.py b/packages/reflex-base/src/reflex_base/utils/serializers.py index 1ba85501284..d20a29e3890 100644 --- a/packages/reflex-base/src/reflex_base/utils/serializers.py +++ b/packages/reflex-base/src/reflex_base/utils/serializers.py @@ -8,6 +8,7 @@ import functools import inspect import json +import uuid import warnings from collections.abc import Callable, Mapping, Sequence from datetime import date, datetime, time, timedelta @@ -34,6 +35,16 @@ SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer) +deserializers = { + int: int, + float: float, + datetime: datetime.fromisoformat, + date: date.fromisoformat, + time: time.fromisoformat, + uuid.UUID: uuid.UUID, +} + + @overload def serializer( fn: None = None, diff --git 
a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 072227a7503..2e8d80b1052 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, BinaryIO, cast from python_multipart.multipart import MultipartParser, parse_options_header -from reflex_base import constants from reflex_base.utils import exceptions +from reflex_base.utils.format import json_dumps from starlette.datastructures import Headers from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.exceptions import HTTPException @@ -22,11 +22,9 @@ from typing_extensions import Self if TYPE_CHECKING: - from reflex_base.event import EventHandler from reflex_base.utils.types import Receive, Scope, Send from reflex.app import App - from reflex.state import BaseState @dataclasses.dataclass(frozen=True) @@ -102,7 +100,7 @@ def __init__(self, *, maxsize: int = 8): self._condition = asyncio.Condition() self._closed = False self._error: Exception | None = None - self._consumer_task: asyncio.Task[Any] | None = None + self._consumer_task: asyncio.Future[Any] | None = None def __aiter__(self) -> Self: """Return the iterator itself. @@ -135,7 +133,7 @@ async def __anext__(self) -> UploadChunk: raise self._error raise StopAsyncIteration - def set_consumer_task(self, task: asyncio.Task[Any]) -> None: + def set_consumer_task(self, task: asyncio.Future[Any]) -> None: """Track the task consuming this iterator. Args: @@ -206,7 +204,7 @@ def _raise_if_consumer_finished(self) -> None: raise RuntimeError(msg) from task_exc raise RuntimeError(msg) - def _wake_waiters(self, task: asyncio.Task[Any]) -> None: + def _wake_waiters(self, task: asyncio.Future[Any]) -> None: """Wake any producers or consumers blocked on the iterator condition. 
Args: @@ -446,51 +444,6 @@ def _require_upload_headers(request: Request) -> tuple[str, str]: return token, handler -async def _get_upload_runtime_handler( - app: App, - token: str, - handler_name: str, -) -> tuple[BaseState, EventHandler]: - """Resolve the runtime state and event handler for an upload request. - - Args: - app: The Reflex app. - token: The client token. - handler_name: The fully qualified event handler name. - - Returns: - The root state instance and resolved event handler. - """ - from reflex.state import _substate_key - - substate_token = _substate_key(token, handler_name.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) - _current_state, event_handler = state._get_event_handler(handler_name) - return state, event_handler - - -def _seed_upload_router_data(state: BaseState, token: str) -> None: - """Ensure upload-launched handlers have the client token in router state. - - Background upload handlers use ``StateProxy`` which derives its mutable-state - token from ``self.router.session.client_token``. Upload requests do not flow - through the normal websocket event pipeline, so we seed the token here. - - Args: - state: The root state instance. - token: The client token from the upload request. 
- """ - from reflex.state import RouterData - - router_data = dict(state.router_data) - if router_data.get(constants.RouteVar.CLIENT_TOKEN) == token: - return - - router_data[constants.RouteVar.CLIENT_TOKEN] = token - state.router_data = router_data - state.router = RouterData.from_router_data(router_data) - - async def _upload_buffered_file( request: Request, app: App, @@ -507,6 +460,8 @@ async def _upload_buffered_file( from reflex_base.event import Event from reflex_base.utils.exceptions import UploadValueError + from reflex.state import StateUpdate + try: form_data = await request.form() except ClientDisconnect: @@ -545,7 +500,6 @@ def _create_upload_event() -> Event: ) return Event( - token=token, name=handler_name, payload={handler_upload_param[0]: file_uploads}, ) @@ -567,12 +521,9 @@ async def _ndjson_updates(): Yields: Each state update as newline-delimited JSON. """ - async with app.state_manager.modify_state_with_links( - event.substate_token, event=event - ) as state: - async for update in state._process(event): - update = await app._postprocess(state, event, update) - yield update.json() + "\n" + # Enqueue the task on the main event loop, but emit deltas to the local queue. 
+ async for delta in app.event_processor.enqueue_stream_delta(token, event): + yield json_dumps(StateUpdate(delta=delta)) + "\n" return _UploadStreamingResponse( _ndjson_updates(), @@ -583,10 +534,9 @@ async def _ndjson_updates(): def _background_upload_accepted_response() -> StreamingResponse: """Return a minimal ndjson response for background upload dispatch.""" - from reflex.state import StateUpdate def _accepted_updates(): - yield StateUpdate(final=True).json() + "\n" + yield "{}\n" return StreamingResponse( _accepted_updates(), @@ -613,23 +563,12 @@ async def _upload_chunk_file( chunk_iter = UploadChunkIterator(maxsize=8) event = Event( - token=token, name=handler_name, payload={handler_upload_param[0]: chunk_iter}, ) + task_future = await app.event_processor.enqueue(token, event) - async with app.state_manager.modify_state_with_links( - event.substate_token, - event=event, - ) as state: - _seed_upload_router_data(state, token) - task = app._process_background(state, event) - - if task is None: - msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." 
- return JSONResponse({"detail": msg}, status_code=400) - - chunk_iter.set_consumer_task(task) + chunk_iter.set_consumer_task(task_future) parser = _UploadChunkMultipartParser( headers=request.headers, @@ -640,9 +579,9 @@ async def _upload_chunk_file( try: await parser.parse() except ClientDisconnect: - task.cancel() + task_future.cancel() with contextlib.suppress(asyncio.CancelledError): - await task + await task_future return Response() except (MultiPartException, RuntimeError, ValueError) as err: await chunk_iter.fail(err) @@ -686,11 +625,13 @@ async def upload_file(request: Request): resolve_upload_chunk_handler_param, resolve_upload_handler_param, ) + from reflex_base.registry import RegistrationContext token, handler_name = _require_upload_headers(request) - _state, event_handler = await _get_upload_runtime_handler( - app, token, handler_name - ) + registered_event_handler = RegistrationContext.get().event_handlers[ + handler_name + ] + event_handler = registered_event_handler.handler if event_handler.is_background: try: diff --git a/pyi_hashes.json b/pyi_hashes.json index 7a0db4b66a2..f6db8e53593 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -118,7 +118,7 @@ "packages/reflex-components-recharts/src/reflex_components_recharts/polar.pyi": "db5298160144f23ae7abcaac68e845c7", "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "75150b01510bdacf2c97fca347c86c59", "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "dc43e142b089b1158588e999505444f6", - "reflex/__init__.pyi": "9321a11f6891d792fcd921cc1bdc64f4", + "reflex/__init__.pyi": "5de3d4af8ea86e9755f622510b868196", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", "reflex/experimental/memo.pyi": "c10cbc554fe2ffdb3a008b59bc503936" } diff --git a/pyproject.toml b/pyproject.toml index 45f81ecad75..39fcbed94c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,7 +234,23 @@ ignore-words-list = "te, TreeE, selectin" 
[tool.coverage.run] -source = ["reflex"] +source = [ + "reflex", + "reflex_components_code", + "reflex_components_core", + "reflex_components_dataeditor", + "reflex_components_gridjs", + "reflex_components_lucide", + "reflex_components_markdown", + "reflex_components_moment", + "reflex_components_plotly", + "reflex_components_radix", + "reflex_components_react_player", + "reflex_components_recharts", + "reflex_components_sonner", + "reflex_base", + "reflex_docgen", +] branch = true omit = [ "*/pyi_generator.py", @@ -247,7 +263,7 @@ omit = [ [tool.coverage.report] show_missing = true # TODO bump back to 79 -fail_under = 50 +fail_under = 72 precision = 2 ignore_errors = true diff --git a/reflex/__init__.py b/reflex/__init__.py index 6aa278482f3..ca246b89616 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -219,6 +219,7 @@ "LocalStorage", "SessionStorage", ], + "istate.manager.token": ["StateToken", "BaseStateToken"], "middleware": ["middleware", "Middleware"], "model": ["asession", "session", "Model", "ModelRegistry"], "page": ["page"], diff --git a/reflex/app.py b/reflex/app.py index f9aa6f5aa5c..e3e66aec45d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -15,20 +15,13 @@ import time import traceback import urllib.parse -from collections.abc import ( - AsyncGenerator, - AsyncIterator, - Callable, - Coroutine, - Mapping, - Sequence, -) +from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence from datetime import datetime from itertools import chain from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ParamSpec +from typing import TYPE_CHECKING, Any, ParamSpec, overload from reflex_base import constants from reflex_base.components.component import ( @@ -45,9 +38,10 @@ EventSpec, EventType, IndividualEventType, - get_hydrate_event, noop, ) +from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor +from 
reflex_base.registry import RegistrationContext from reflex_base.utils import console from reflex_base.utils.imports import ImportVar from reflex_base.utils.types import ASGIApp, Message, Receive, Scope, Send @@ -87,7 +81,7 @@ ) from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager import StateManager, StateModificationContext -from reflex.istate.proxy import StateProxy +from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES from reflex.route import ( get_route_args, @@ -99,8 +93,6 @@ RouterData, State, StateUpdate, - _split_substate_key, - _substate_key, all_base_state_classes, code_uses_state_contexts, ) @@ -122,6 +114,11 @@ from reflex.utils.misc import run_in_thread from reflex.utils.token_manager import RedisTokenManager, TokenManager +if sys.version_info < (3, 13): + from typing_extensions import deprecated +else: + from warnings import deprecated + if TYPE_CHECKING: from reflex_base.vars import Var @@ -155,24 +152,26 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec: """ from reflex_components_sonner.toast import toast - error = traceback.format_exc() + error = traceback.format_exception( + type(exception), exception, exception.__traceback__ + ) - console.error(f"[Reflex Backend Exception]\n {error}\n") + console.error(f"[Reflex Backend Exception]\n {''.join(error)}\n") error_message = ( ["Contact the website administrator."] if is_prod_mode() - else [f"{type(exception).__name__}: {exception}.", "See logs for details."] + else [f"{type(exception).__name__}: {exception}", "See logs for details."] ) return toast( "An error occurred.", level="error", fallback_to_alert=True, - description="
".join(error_message), + description="\n".join(error_message), position="top-center", id="backend_error", - style={"width": "500px"}, + style={"width": "500px", "white-space": "pre-wrap"}, ) @@ -397,8 +396,13 @@ class App(MiddlewareMixin, LifespanMixin): # The async server name space. _event_namespace: EventNamespace | None = None - # Background tasks that are currently running. - _background_tasks: set[asyncio.Task] = dataclasses.field(default_factory=set) + # The processor queue for handling events. + _event_processor: EventProcessor | None = None + + # Store the RegistrationContext to apply inside the ASGI callable task. + _registration_context: RegistrationContext = dataclasses.field( + default_factory=RegistrationContext.ensure_context + ) frontend_exception_handler: Callable[[Exception], None] = ( default_frontend_exception_handler @@ -426,6 +430,18 @@ def event_namespace(self) -> EventNamespace | None: """ return self._event_namespace + @property + def event_processor(self) -> EventProcessor: + """Get the event processor. + + Raises: + RuntimeError: If the event processor is not initialized. + """ + if self._event_processor is None: + msg = "Event processor is not initialized." + raise RuntimeError(msg) + return self._event_processor + def __post_init__(self): """Initialize the app. @@ -486,7 +502,7 @@ def _setup_state(self) -> None: config = get_config() # Set up the state manager. - self._state_manager = StateManager.create(state=self._state) + self._state_manager = StateManager.create() # Set up the Socket.IO AsyncServer. if not self.sio: @@ -559,6 +575,40 @@ async def modified_send(message: Message): # Check the exception handlers self._validate_exception_handlers() + # Ensure the event processor starts and stops with the server. + self.register_lifespan_task(self._setup_event_processor) + + def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: + """Ensure the RegistrationContext is attached to the ASGI app. 
+ + Args: + app: The ASGI app to attach the middleware to. + + Returns: + The ASGI app with the middleware attached. + """ + + async def registration_context_middleware( + scope: Scope, receive: Receive, send: Send + ): + if self._registration_context is not None: + RegistrationContext.set(self._registration_context) + await app(scope, receive, send) + + return registration_context_middleware + + @contextlib.asynccontextmanager + async def _setup_event_processor(self) -> AsyncIterator[None]: + # Create the event processor. + self._event_processor = BaseStateEventProcessor( + middleware=self, backend_exception_handler=self.backend_exception_handler + ) + async with self._event_processor.configure( + state_manager=self.state_manager, + event_namespace=self.event_namespace, + ): + yield + def __repr__(self) -> str: """Get the string representation of the app. @@ -630,9 +680,12 @@ def __call__(self) -> ASGIApp: asgi_app = api_transformer(asgi_app) top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) - top_asgi_app.mount("", asgi_app) + # Make sure the RegistrationContext is attached. 
+ top_asgi_app.mount( + "", + self._registration_context_middleware(asgi_app), + ) App._add_cors(top_asgi_app) - return top_asgi_app def _add_default_endpoints(self): @@ -1103,7 +1156,7 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No msg = f"ComputedVar {var._name} on state {state.__name__} has an invalid dependency {state_name}.{dep}" raise exceptions.VarDependencyError(msg) - for substate in state.class_subclasses: + for substate in state.get_substates(): self._validate_var_dependencies(substate) def _compile( @@ -1530,10 +1583,27 @@ def all_routes(_request: Request) -> Response: str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"] ) + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state( + self, + token: str, + background: bool = False, + previous_dirty_vars: dict[str, set[str]] | None = None, + ) -> contextlib.AbstractAsyncContextManager[BaseState]: ... + + @overload + def modify_state( + self, + token: BaseStateToken, + background: bool = False, + previous_dirty_vars: dict[str, set[str]] | None = None, + ) -> contextlib.AbstractAsyncContextManager[BaseState]: ... + @contextlib.asynccontextmanager async def modify_state( self, - token: str, + token: BaseStateToken | str, background: bool = False, previous_dirty_vars: dict[str, set[str]] | None = None, **context: Unpack[StateModificationContext], @@ -1555,6 +1625,9 @@ async def modify_state( msg = "App has not been initialized yet." raise RuntimeError(msg) + if isinstance(token, str): + token = BaseStateToken.from_legacy_token(token, root_state=self._state) + # Get exclusive access to the state. async with self.state_manager.modify_state_with_links( token, previous_dirty_vars=previous_dirty_vars, **context @@ -1566,64 +1639,10 @@ async def modify_state( if delta: # When the frontend vars are modified emit the delta to the frontend. 
await self.event_namespace.emit_update( - update=StateUpdate( - delta=delta, - final=True if not background else None, - ), - token=token, - ) - - def _process_background( - self, state: BaseState, event: Event - ) -> asyncio.Task | None: - """Process an event in the background and emit updates as they arrive. - - Args: - state: The state to process the event for. - event: The event to process. - - Returns: - Task if the event was backgroundable, otherwise None - """ - substate, handler = state._get_event_handler(event) - - if not handler.is_background: - return None - - substate = StateProxy(substate, event) - - async def _coro(): - """Coroutine to process the event and emit updates inside an asyncio.Task. - - Raises: - RuntimeError: If the app has not been initialized yet. - """ - if self.event_namespace is None: - msg = "App has not been initialized yet." - raise RuntimeError(msg) - - # Process the event. - async for update in state._process_event( - handler=handler, state=substate, payload=event.payload - ): - # Postprocess the event. - update = await self._postprocess(state, event, update) - - # Send the update to the client. - await self.event_namespace.emit_update( - update=update, - token=event.token, + update=StateUpdate(delta=delta), + token=token.ident, ) - task = asyncio.create_task( - _coro(), - name=f"reflex_background_task|{event.name}|{time.time()}|{event.token}", - ) - self._background_tasks.add(task) - # Clean up task from background_tasks set when complete. - task.add_done_callback(self._background_tasks.discard) - return task - def _validate_exception_handlers(self): """Validate the custom event exception handlers for front- and backend. @@ -1717,95 +1736,6 @@ def _validate_exception_handlers(self): raise ValueError(msg) -async def process( - app: App, event: Event, sid: str, headers: dict, client_ip: str -) -> AsyncGenerator[StateUpdate]: - """Process an event. - - Args: - app: The app to process the event for. - event: The event to process. 
- sid: The Socket.IO session id. - headers: The client headers. - client_ip: The client_ip. - - Yields: - The state updates after processing the event. - - Raises: - Exception: If a reflex specific error occurs during processing the event. - """ - from reflex.utils import telemetry - - try: - # Add request data to the state. - router_data = event.router_data - router_data.update({ - constants.RouteVar.QUERY: format.format_query_params(event.router_data), - constants.RouteVar.CLIENT_TOKEN: event.token, - constants.RouteVar.SESSION_ID: sid, - constants.RouteVar.HEADERS: headers, - constants.RouteVar.CLIENT_IP: client_ip, - }) - # Get the state for the session exclusively. - async with app.state_manager.modify_state_with_links( - event.substate_token, event=event - ) as state: - # When this is a brand new instance of the state, signal the - # frontend to reload before processing it. - if ( - not state.router_data - and event.name != get_hydrate_event(state) - and app.event_namespace is not None - ): - await asyncio.create_task( - app.event_namespace.emit( - "reload", - data=event, - to=sid, - ), - name=f"reflex_emit_reload|{event.name}|{time.time()}|{event.token}", - ) - return - router_data[constants.RouteVar.PATH] = "/" + ( - app.router(path) or "404" - if (path := router_data.get(constants.RouteVar.PATH)) - else "404" - ).removeprefix("/") - # re-assign only when the value is different - if state.router_data != router_data: - # assignment will recurse into substates and force recalculation of - # dependent ComputedVar (dynamic route variables) - state.router_data = router_data - state.router = RouterData.from_router_data(router_data) - - # Preprocess the event. - update = await app._preprocess(state, event) - - # If there was an update, yield it. - if update is not None: - yield update - - # Only process the event if there is no update. - else: - if app._process_background(state, event) is not None: - # `final=True` allows the frontend send more events immediately. 
- yield StateUpdate(final=True) - else: - # Process the event synchronously. - async for update in state._process(event): - # Postprocess the event. - update = await app._postprocess(state, event, update) - - # Yield the update. - yield update - except Exception as ex: - telemetry.send_error(ex, context="backend") - - app.backend_exception_handler(ex) - raise - - def ping(_request: Request) -> Response: """Test API endpoint. @@ -1955,17 +1885,14 @@ async def emit_update(self, update: StateUpdate, token: str) -> None: update: The state update to send. token: The client token (tab) associated with the event. """ - client_token, _ = _split_substate_key(token) - socket_record = self._token_manager.token_to_socket.get(client_token) + socket_record = self._token_manager.token_to_socket.get(token) if ( socket_record is None or socket_record.instance_id != self._token_manager.instance_id ): if isinstance(self._token_manager, RedisTokenManager): # The socket belongs to another instance of the app, send it to the lost and found. - if not await self._token_manager.emit_lost_and_found( - client_token, update - ): + if not await self._token_manager.emit_lost_and_found(token, update): console.warn( f"Failed to send delta to lost and found for client {token!r}" ) @@ -1993,6 +1920,13 @@ async def on_event(self, sid: str, data: Any): RuntimeError: If the Socket.IO is badly initialized. EventDeserializationError: If the event data is not a dictionary. """ + # Determine the token for this SID + if (token := self.sid_to_token.get(sid)) is None: + console.warn( + f"Received event from session {sid} with no associated token. This may indicate a bug. Event data: {data}" + ) + return + fields = data if isinstance(fields, str): @@ -2017,14 +1951,6 @@ async def on_event(self, sid: str, data: Any): msg = f"Failed to deserialize event data: {fields}." 
raise exceptions.EventDeserializationError(msg) from ex - # Correct the token if it doesn't match what we expect for this SID - expected_token = self.sid_to_token.get(sid) - if expected_token and event.token != expected_token: - # Create new event with corrected token since Event is frozen - from dataclasses import replace - - event = replace(event, token=expected_token) - # Get the event environment. if self.app.sio is None: msg = "Socket.IO is not initialized." @@ -2057,14 +1983,20 @@ async def on_event(self, sid: str, data: Any): .partition(",")[0] .strip() ) - - async with contextlib.aclosing( - process(self.app, event, sid, headers, client_ip) - ) as updates_gen: - # Process the events. - async for update in updates_gen: - # Emit the update from processing the event. - await self.emit_update(update=update, token=event.token) + router_data = event.router_data + router_data.update({ + constants.RouteVar.QUERY: format.format_query_params(event.router_data), + constants.RouteVar.CLIENT_TOKEN: token, + constants.RouteVar.SESSION_ID: sid, + constants.RouteVar.HEADERS: headers, + constants.RouteVar.CLIENT_IP: client_ip, + }) + router_data[constants.RouteVar.PATH] = "/" + ( + self.app.router(path) or "404" + if (path := router_data.get(constants.RouteVar.PATH)) + else "404" + ).removeprefix("/") + await self.app.event_processor.enqueue(token, event) async def on_ping(self, sid: str): """Event for testing the API endpoint. @@ -2090,8 +2022,9 @@ async def link_token_to_sid(self, sid: str, token: str): await self.emit("new_token", new_token, to=sid) # Update client state to apply new sid/token for running background tasks. 
- async with self.app.state_manager.modify_state( - _substate_key(new_token or token, self.app.state_manager.state) - ) as state: - state.router_data[constants.RouteVar.SESSION_ID] = sid - state.router = RouterData.from_router_data(state.router_data) + if self.app._state is not None: + async with self.app.state_manager.modify_state( + BaseStateToken(ident=new_token or token, cls=self.app._state) + ) as state: + state.router_data[constants.RouteVar.SESSION_ID] = sid + state.router = RouterData.from_router_data(state.router_data) diff --git a/reflex/app_mixins/middleware.py b/reflex/app_mixins/middleware.py index b8f54f71aa6..c9bea19e922 100644 --- a/reflex/app_mixins/middleware.py +++ b/reflex/app_mixins/middleware.py @@ -7,7 +7,7 @@ from reflex_base.event import Event -from reflex.middleware import HydrateMiddleware, Middleware +from reflex.middleware import Middleware from reflex.state import BaseState, StateUpdate from .mixin import AppMixin @@ -20,9 +20,6 @@ class MiddlewareMixin(AppMixin): # Middleware to add to the app. Users should use `add_middleware`. _middlewares: list[Middleware] = dataclasses.field(default_factory=list) - def _init_mixin(self): - self._middlewares.append(HydrateMiddleware()) - def add_middleware(self, middleware: Middleware, index: int | None = None): """Add middleware to the app. 
diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 2b067596db9..5e3d71ee0e6 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -4,17 +4,20 @@ import dataclasses from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict, overload from reflex_base import constants from reflex_base.config import get_config from reflex_base.event import Event from reflex_base.utils.exceptions import InvalidStateManagerModeError -from typing_extensions import ReadOnly, Unpack +from typing_extensions import ReadOnly, Unpack, deprecated -from reflex.state import BaseState +from reflex.istate.manager.token import TOKEN_TYPE, StateToken from reflex.utils import console, prerequisites +if TYPE_CHECKING: + from reflex.state import BaseState + class StateModificationContext(TypedDict, total=False): """The context for modifying state.""" @@ -27,21 +30,32 @@ class StateModificationContext(TypedDict, total=False): @dataclasses.dataclass class StateManager(ABC): - """A class to manage many client states. + """A class to manage many client states.""" - Attributes: - state: The state class to use. - """ + @property + def state(self): + """Get the state class. + + Deprecated: the state manager no longer holds a reference to the state class. + + Returns: + The State class. + """ + console.deprecate( + feature_name="StateManager.state", + reason="The state manager no longer holds a reference to the state class. " + "Use reflex.state.State directly instead.", + deprecation_version="0.9.0", + removal_version="1.0", + ) + from reflex.state import State - state: type[BaseState] + return State @classmethod - def create(cls, state: type[BaseState]): + def create(cls): """Create a new state manager. - Args: - state: The state class to use. - Returns: The state manager (either disk, memory or redis). 
@@ -57,11 +71,11 @@ def create(cls, state: type[BaseState]): if config.state_manager_mode == constants.StateManagerMode.MEMORY: from reflex.istate.manager.memory import StateManagerMemory - return StateManagerMemory(state=state) + return StateManagerMemory() if config.state_manager_mode == constants.StateManagerMode.DISK: from reflex.istate.manager.disk import StateManagerDisk - return StateManagerDisk(state=state) + return StateManagerDisk() if config.state_manager_mode == constants.StateManagerMode.REDIS: redis = prerequisites.get_redis() if redis is not None: @@ -69,7 +83,6 @@ def create(cls, state: type[BaseState]): # make sure expiration values are obtained only from the config object on creation return StateManagerRedis( - state=state, redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, @@ -78,8 +91,79 @@ def create(cls, state: type[BaseState]): msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" raise InvalidStateManagerModeError(msg) + @staticmethod + def _coerce_token(token: StateToken[TOKEN_TYPE] | str) -> StateToken[TOKEN_TYPE]: + """Convert a legacy string token to a StateToken if needed. + + Args: + token: The token, either a StateToken or legacy string. + + Returns: + The coerced StateToken. + """ + if isinstance(token, str): + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State + + return BaseStateToken.from_legacy_token(token, root_state=State) # type: ignore[return-value] + return token + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def get_state(self, token: str) -> "BaseState": ... + + @overload + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def set_state( + self, + token: str, + state: "BaseState", + **context: Unpack[StateModificationContext], + ) -> None: ... 
+ + @overload + async def set_state( + self, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], + ) -> None: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state( + self, token: str, **context: Unpack[StateModificationContext] + ) -> contextlib.AbstractAsyncContextManager["BaseState"]: ... + + @overload + def modify_state( + self, + token: StateToken[TOKEN_TYPE], + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state_with_links( + self, + token: str, + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager["BaseState"]: ... + + @overload + def modify_state_with_links( + self, + token: StateToken[TOKEN_TYPE], + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + @abstractmethod - async def get_state(self, token: str) -> BaseState: + async def get_state(self, token: StateToken[TOKEN_TYPE] | str) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -92,8 +176,8 @@ async def get_state(self, token: str) -> BaseState: @abstractmethod async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE] | str, + state: TOKEN_TYPE, **context: Unpack[StateModificationContext], ): """Set the state for a token. 
@@ -107,8 +191,10 @@ async def set_state( @abstractmethod @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, + token: StateToken[TOKEN_TYPE] | str, + **context: Unpack[StateModificationContext], + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -118,15 +204,15 @@ async def modify_state( Yields: The state for the token. """ - yield self.state() + yield # pyright: ignore[reportReturnType] @contextlib.asynccontextmanager async def modify_state_with_links( self, - token: str, + token: StateToken[TOKEN_TYPE] | str, previous_dirty_vars: dict[str, set[str]] | None = None, **context: Unpack[StateModificationContext], - ) -> AsyncIterator[BaseState]: + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token, including linked substates, while holding exclusive lock. Args: @@ -137,8 +223,14 @@ async def modify_state_with_links( Yields: The state for the token with linked states patched in. """ + from reflex.state import BaseState + + token = self._coerce_token(token) async with self.modify_state(token, **context) as root_state: - if getattr(root_state, "_reflex_internal_links", None) is not None: + if ( + isinstance(root_state, BaseState) + and getattr(root_state, "_reflex_internal_links", None) is not None + ): from reflex.istate.shared import SharedStateBaseInternal shared_state = await root_state.get_state(SharedStateBaseInternal) @@ -177,4 +269,6 @@ def get_state_manager() -> StateManager: Returns: The state manager. 
""" - return prerequisites.get_and_validate_app().app.state_manager + from reflex_base.event.context import EventContext + + return EventContext.get().state_manager diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 3b3c15b1cb1..9b75533ad93 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -8,6 +8,7 @@ from collections.abc import AsyncIterator from hashlib import md5 from pathlib import Path +from typing import Any, Generic, cast from reflex_base.environment import environment from typing_extensions import Unpack, override @@ -17,17 +18,18 @@ StateModificationContext, _default_token_expiration, ) -from reflex.state import BaseState, _split_substate_key, _substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken +from reflex.state import BaseState from reflex.utils import console, path_ops, prerequisites from reflex.utils.misc import run_in_thread @dataclasses.dataclass(frozen=True) -class QueueItem: +class QueueItem(Generic[TOKEN_TYPE]): """An item in the write queue.""" - token: str - state: BaseState + token: StateToken[TOKEN_TYPE] + state: TOKEN_TYPE timestamp: float @@ -36,7 +38,7 @@ class StateManagerDisk(StateManager): """A state manager that stores states on disk.""" # The mapping of client ids to states. 
- states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + states: dict[str, Any] = dataclasses.field(default_factory=dict) # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) @@ -57,7 +59,7 @@ class StateManagerDisk(StateManager): ) # Pending writes - _write_queue: dict[str, QueueItem] = dataclasses.field( + _write_queue: dict[StateToken, QueueItem] = dataclasses.field( default_factory=dict, init=False, ) @@ -96,7 +98,7 @@ def _purge_expired_states(self): # remove the file path.unlink() - def token_path(self, token: str) -> Path: + def token_path(self, token: StateToken) -> Path: """Get the path for a token. Args: @@ -106,10 +108,10 @@ def token_path(self, token: str) -> Path: The path for the token. """ return ( - self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" + self.states_directory / f"{md5(str(token).encode()).hexdigest()}.pkl" ).absolute() - async def load_state(self, token: str) -> BaseState | None: + async def load_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE | None: """Load a state object based on the provided token. Args: @@ -123,23 +125,23 @@ async def load_state(self, token: str) -> BaseState | None: if token_path.exists(): try: with token_path.open(mode="rb") as file: - return BaseState._deserialize(fp=file) + return token.deserialize(fp=file) except Exception: pass return None async def populate_substates( - self, client_token: str, state: BaseState, root_state: BaseState + self, token: BaseStateToken, state: BaseState, root_state: BaseState ): """Populate the substates of a state object. Args: - client_token: The client token. + token: The token used to identify the state object. state: The state object to populate. root_state: The root state object. 
""" for substate in state.get_substates(): - substate_token = _substate_key(client_token, substate) + substate_token = token.with_cls(substate) fresh_instance = await root_state.get_state(substate) instance = await self.load_state(substate_token) @@ -151,13 +153,13 @@ async def populate_substates( state.substates[substate.get_name()] = instance instance.parent_state = state - await self.populate_substates(client_token, instance, root_state) + await self.populate_substates(token, instance, root_state) @override async def get_state( self, - token: str, - ) -> BaseState: + token: StateToken[TOKEN_TYPE], + ) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -166,38 +168,51 @@ async def get_state( Returns: The state for the token. """ - client_token = _split_substate_key(token)[0] - self._token_last_touched[client_token] = time.time() - root_state = self.states.get(client_token) + token = self._coerce_token(token) + root_state = self.states.get(token.cache_key) + self._token_last_touched[token.cache_key] = time.time() if root_state is not None: # Retrieved state from memory. return root_state # Deserialize root state from disk. - root_state = await self.load_state(_substate_key(client_token, self.state)) - # Create a new root state tree with all substates instantiated. - fresh_root_state = self.state(_reflex_internal_init=True) - if root_state is None: - root_state = fresh_root_state - else: - # Ensure all substates exist, even if they were not serialized previously. 
- root_state.substates = fresh_root_state.substates - self.states[client_token] = root_state - await self.populate_substates(client_token, root_state, root_state) - return root_state - - async def set_state_for_substate(self, client_token: str, substate: BaseState): + if isinstance(token, BaseStateToken): + # Find the root state + root_state_cls = token.cls.get_root_state() + root_state = await self.load_state(token.with_cls(root_state_cls)) + # Create a new root state tree with all substates instantiated. + fresh_root_state = root_state_cls(_reflex_internal_init=True) + if root_state is None: + root_state = fresh_root_state + elif not isinstance(root_state, BaseState): + msg = "Deserialized state is not an instance of BaseState, cannot populate substates." + raise TypeError(msg) + else: + # Ensure all substates exist, even if they were not serialized previously. + root_state.substates = fresh_root_state.substates + await self.populate_substates(token, root_state, root_state) + self.states[token.cache_key] = root_state + return cast(TOKEN_TYPE, root_state) + # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls. + state = await self.load_state(token) + if state is None: + state = token.cls() + self.states[token.cache_key] = state + return cast(TOKEN_TYPE, state) + + async def set_state_for_substate( + self, token: StateToken[TOKEN_TYPE], substate: TOKEN_TYPE + ): """Set the state for a substate. Args: - client_token: The client token. + token: The token used to identify the state object. substate: The substate to set. """ - substate_token = _substate_key(client_token, substate) + substate_token = token.with_cls(type(substate)) - if substate._get_was_touched(): - substate._was_touched = False # Reset the touched flag after serializing. 
- pickle_state = substate._serialize() + if token.get_and_reset_touched_state(substate): + pickle_state = token.serialize(substate) if pickle_state: if not self.states_directory.exists(): self.states_directory.mkdir(parents=True, exist_ok=True) @@ -205,8 +220,9 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState): lambda: self.token_path(substate_token).write_bytes(pickle_state), ) - for substate_substate in substate.substates.values(): - await self.set_state_for_substate(client_token, substate_substate) + if isinstance(token, BaseStateToken) and isinstance(substate, BaseState): + for substate_substate in substate.substates.values(): + await self.set_state_for_substate(token, substate_substate) async def _process_write_queue_delay(self): """Wait for the debounce period before processing the write queue again.""" @@ -252,15 +268,14 @@ async def _process_write_queue(self): ) for item in items_to_write: token = item.token - client_token, _ = _split_substate_key(token) await self.set_state_for_substate( - client_token, self._write_queue.pop(token).state + token, self._write_queue.pop(token).state ) # Check for expired states to purge. 
- for token, last_touched in list(self._token_last_touched.items()): + for cache_key, last_touched in list(self._token_last_touched.items()): if now - last_touched > self.token_expiration: - self._token_last_touched.pop(token) - self.states.pop(token, None) + self._token_last_touched.pop(cache_key) + self.states.pop(cache_key, None) await run_in_thread(self._purge_expired_states) await self._process_write_queue_delay() except asyncio.CancelledError: # noqa: PERF203 @@ -283,10 +298,8 @@ async def _flush_write_queue(self): f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk" ) for item in outstanding_items: - token = item.token - client_token, _ = _split_substate_key(token) await self.set_state_for_substate( - client_token, + item.token, item.state, ) console.debug( @@ -306,7 +319,10 @@ async def _schedule_process_write_queue(self): @override async def set_state( - self, token: str, state: BaseState, **context: Unpack[StateModificationContext] + self, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], ): """Set the state for a token. @@ -315,26 +331,26 @@ async def set_state( state: The state to set. context: The state modification context. """ - client_token, _ = _split_substate_key(token) + token = self._coerce_token(token) if self._write_debounce_seconds > 0: # Deferred write to reduce disk IO overhead. - if client_token not in self._write_queue: - self._write_queue[client_token] = QueueItem( - token=client_token, + if token not in self._write_queue: + self._write_queue[token] = QueueItem( + token=token, state=state, timestamp=time.time(), ) else: # Immediate write to disk. - await self.set_state_for_substate(client_token, state) + await self.set_state_for_substate(token, state) # Ensure the processing task is scheduled to handle expirations and any deferred writes. 
await self._schedule_process_write_queue() @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -344,14 +360,15 @@ async def modify_state( Yields: The state for the token. """ + token = self._coerce_token(token) # Disk state manager ignores the substate suffix and always returns the top-level state. - client_token, _ = _split_substate_key(token) - if client_token not in self._states_locks: + lock_key = token.lock_key + if lock_key not in self._states_locks: async with self._state_manager_lock: - if client_token not in self._states_locks: - self._states_locks[client_token] = asyncio.Lock() + if lock_key not in self._states_locks: + self._states_locks[lock_key] = asyncio.Lock() - async with self._states_locks[client_token]: + async with self._states_locks[lock_key]: state = await self.get_state(token) yield state await self.set_state(token, state, **context) @@ -364,3 +381,7 @@ async def close(self): with contextlib.suppress(asyncio.CancelledError): await self._write_queue_task self._write_queue_task = None + # Dump unlocked locks. 
+ for token, lock in tuple(self._states_locks.items()): + if not lock.locked(): + self._states_locks.pop(token) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index ec898388ebd..07d4dc27926 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -5,6 +5,7 @@ import dataclasses import time from collections.abc import AsyncIterator +from typing import Any, cast from typing_extensions import Unpack, override @@ -13,7 +14,7 @@ StateModificationContext, _default_token_expiration, ) -from reflex.state import BaseState, _split_substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken @dataclasses.dataclass @@ -24,7 +25,7 @@ class StateManagerMemory(StateManager): token_expiration: int = dataclasses.field(default_factory=_default_token_expiration) # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + states: dict[str, Any] = dataclasses.field(default_factory=dict) # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) @@ -35,15 +36,15 @@ class StateManagerMemory(StateManager): init=False, ) - # The latest expiration deadline for each token. - _token_expires_at: dict[str, float] = dataclasses.field( + # The latest expiration deadline and token for each cache key. + _token_expires_at: dict[str, tuple[float, StateToken]] = dataclasses.field( default_factory=dict, init=False, ) _expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False) - def _get_or_create_state(self, token: str) -> BaseState: + def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: """Get an existing state or create a fresh one for a token. Args: @@ -52,21 +53,33 @@ def _get_or_create_state(self, token: str) -> BaseState: Returns: The state for the token. 
""" - state = self.states.get(token) - if state is None: - state = self.states[token] = self.state(_reflex_internal_init=True) - return state - - def _track_token(self, token: str): + key = token.cache_key + if key not in self.states: + if isinstance(token, BaseStateToken): + self.states[key] = token.cls.get_root_state()( + _reflex_internal_init=True + ) + else: + self.states[key] = token.cls() + return cast(TOKEN_TYPE, self.states[key]) + + def _track_token(self, token: StateToken): """Refresh the expiration deadline for an active token.""" - self._token_expires_at[token] = time.time() + self.token_expiration + self._token_expires_at[token.cache_key] = ( + time.time() + self.token_expiration, + token, + ) self._ensure_expiration_task() - def _purge_token(self, token: str): - """Remove a token from in-memory state bookkeeping.""" - self._token_expires_at.pop(token, None) - self.states.pop(token, None) - self._states_locks.pop(token, None) + def _purge_token(self, token: StateToken): + """Remove a token from in-memory state bookkeeping. + + Args: + token: The token to purge. + """ + self._token_expires_at.pop(token.cache_key, None) + self._states_locks.pop(token.lock_key, None) + self.states.pop(token.cache_key, None) def _purge_expired_tokens(self) -> float | None: """Purge expired in-memory state entries and return the next deadline. 
@@ -79,9 +92,9 @@ def _purge_expired_tokens(self) -> float | None: token_expires_at = self._token_expires_at state_locks = self._states_locks - for token, expires_at in list(token_expires_at.items()): + for _cache_key, (expires_at, token) in list(token_expires_at.items()): if ( - state_lock := state_locks.get(token) + state_lock := state_locks.get(token.lock_key) ) is not None and state_lock.locked(): continue if expires_at <= now: @@ -92,7 +105,7 @@ def _purge_expired_tokens(self) -> float | None: return next_expires_at - async def _get_state_lock(self, token: str) -> asyncio.Lock: + async def _get_state_lock(self, token: StateToken) -> asyncio.Lock: """Get or create the lock for a token. Args: @@ -101,12 +114,12 @@ async def _get_state_lock(self, token: str) -> asyncio.Lock: Returns: The lock protecting the token's state. """ - state_lock = self._states_locks.get(token) + state_lock = self._states_locks.get(token.lock_key) if state_lock is None: async with self._state_manager_lock: - state_lock = self._states_locks.get(token) + state_lock = self._states_locks.get(token.lock_key) if state_lock is None: - state_lock = self._states_locks[token] = asyncio.Lock() + state_lock = self._states_locks[token.lock_key] = asyncio.Lock() return state_lock async def _expire_states(self): @@ -130,7 +143,7 @@ def _ensure_expiration_task(self): ) @override - async def get_state(self, token: str) -> BaseState: + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -139,8 +152,7 @@ async def get_state(self, token: str) -> BaseState: Returns: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. 
- token = _split_substate_key(token)[0] + token = self._coerce_token(token) state = self._get_or_create_state(token) self._track_token(token) return state @@ -148,8 +160,8 @@ async def get_state(self, token: str) -> BaseState: @override async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, **context: Unpack[StateModificationContext], ): """Set the state for a token. @@ -159,15 +171,15 @@ async def set_state( state: The state to set. context: The state modification context. """ - token = _split_substate_key(token)[0] - self.states[token] = state + token = self._coerce_token(token) + self.states[token.cache_key] = state self._track_token(token) @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -177,8 +189,7 @@ async def modify_state( Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] + token = self._coerce_token(token) state_lock = await self._get_state_lock(token) try: @@ -204,3 +215,7 @@ async def close(self): with contextlib.suppress(asyncio.CancelledError): await self._expiration_task self._expiration_task = None + # Dump unlocked locks. 
+ for token, lock in tuple(self._states_locks.items()): + if not lock.locked(): + self._states_locks.pop(token) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index ee4e26bd761..3a98e16c0d4 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -9,7 +9,7 @@ import time import uuid from collections.abc import AsyncIterator -from typing import TypedDict +from typing import Any, TypedDict, cast from redis import ResponseError from redis.asyncio import Redis @@ -28,7 +28,8 @@ StateModificationContext, _default_token_expiration, ) -from reflex.state import BaseState, _split_substate_key, _substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken +from reflex.state import BaseState from reflex.utils.tasks import ensure_task @@ -138,9 +139,7 @@ class StateManagerRedis(StateManager): ) # Cached states - _cached_states: dict[str, BaseState] = dataclasses.field( - default_factory=dict, init=False - ) + _cached_states: dict[str, Any] = dataclasses.field(default_factory=dict, init=False) _cached_states_locks: dict[str, asyncio.Lock] = dataclasses.field( default_factory=dict, init=False ) @@ -265,32 +264,32 @@ def _get_populated_states( @override async def get_state( self, - token: str, + token: StateToken[TOKEN_TYPE], top_level: bool = True, for_state_instance: BaseState | None = None, - ) -> BaseState: + ) -> TOKEN_TYPE: """Get the state for a token. Args: token: The token to get the state for. - top_level: If true, return an instance of the top-level state (self.state). + top_level: If true, return the top-level root state. for_state_instance: If provided, attach the requested states to this existing state tree. Returns: The state for the token. Raises: - RuntimeError: when the state_cls is not specified in the token, or when the parent state for a - requested state was not fetched. + RuntimeError: when the parent state for a requested state was not fetched. 
""" - # Split the actual token from the fully qualified substate name. - token, state_path = _split_substate_key(token) - if state_path: - # Get the State class associated with the given path. - requested_state_cls = self.state.get_class_substate(state_path) - else: - msg = f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" - raise RuntimeError(msg) + token = self._coerce_token(token) + if not isinstance(token, BaseStateToken): + # Non-BaseState token: simple single-key fetch. + redis_data = await self.redis.get(str(token)) + if redis_data is not None: + return token.deserialize(data=redis_data) + return token.cls() + + requested_state_cls = token.cls # Determine which states we already have. flat_state_tree: dict[str, BaseState] = ( @@ -306,7 +305,7 @@ async def get_state( redis_pipeline = self.redis.pipeline() for state_cls in required_state_classes: - redis_pipeline.get(_substate_key(token, state_cls)) + redis_pipeline.get(str(token.with_cls(state_cls))) for state_cls, redis_state in zip( required_state_classes, @@ -344,14 +343,17 @@ async def get_state( # To retain compatibility with previous implementation, by default, we return # the top-level state which should always be fetched or already cached. if top_level: - return flat_state_tree[self.state.get_full_name()] - return flat_state_tree[requested_state_cls.get_full_name()] + return cast( + TOKEN_TYPE, + flat_state_tree[requested_state_cls.get_root_state().get_full_name()], + ) + return cast(TOKEN_TYPE, flat_state_tree[requested_state_cls.get_full_name()]) @override async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, *, lock_id: bytes | None = None, **context: Unpack[StateModificationContext], @@ -361,13 +363,14 @@ async def set_state( Args: token: The token to set the state for. state: The state to set. - lock_id: If provided, the lock_key must be set to this value to set the state. 
+ lock_id: If provided, the lock must be held with this value to set the state. context: The event context. Raises: LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. RuntimeError: If the state instance doesn't match the state name in the token. """ + token = self._coerce_token(token) # Check that we're holding the lock. if ( lock_id is not None @@ -387,9 +390,18 @@ async def set_state( ) raise LockExpiredError(msg) - client_token, substate_name = _split_substate_key(token) + if not isinstance(token, BaseStateToken): + # Non-BaseState token: simple single-key write. + pickle_state = token.serialize(state) + if pickle_state: + await self.redis.set(str(token), pickle_state, ex=self.token_expiration) + return + + base_state = cast(BaseState, state) + + lock_key = token.lock_key - if lock_id is not None and client_token not in self._local_leases: + if lock_id is not None and lock_key not in self._local_leases: time_taken = ( self.lock_expiration - (await self.redis.pttl(self._lock_key(token))) ) / 1000 @@ -405,30 +417,25 @@ async def set_state( dedupe=True, ) - # If the substate name on the token doesn't match the instance name, it cannot have a parent. - if state.parent_state is not None and state.get_full_name() != substate_name: - msg = f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." - raise RuntimeError(msg) - # Recursively set_state on all known substates. tasks = [ asyncio.create_task( self.set_state( - _substate_key(client_token, substate), + token, substate, lock_id=lock_id, **context, ), - name=f"reflex_set_state|{client_token}|{substate.get_full_name()}", + name=f"reflex_set_state|{lock_key}|{substate.get_full_name()}", ) - for substate in state.substates.values() + for substate in base_state.substates.values() ] # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). 
- if state._get_was_touched(): - pickle_state = state._serialize() + if base_state._get_was_touched(): + pickle_state = base_state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + str(token.with_cls(type(base_state))), pickle_state, ex=self.token_expiration, ) @@ -439,8 +446,8 @@ async def set_state( @contextlib.asynccontextmanager async def _try_modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState | None]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE | None]: """Modify the state for a token while holding exclusive lock. Args: @@ -467,7 +474,7 @@ async def _try_modify_state( return # Opportunistic locking is enabled, so try to hold the lock across multiple calls. - client_token, _ = _split_substate_key(token) + lock_key = token.lock_key lock_held_ctx = contextlib.AsyncExitStack() try: lock_id = await lock_held_ctx.enter_async_context( @@ -479,12 +486,12 @@ async def _try_modify_state( else: # Do not create a lease break task when multiple instances are waiting. if ( - not await self._get_local_lease(client_token) + not await self._get_local_lease(lock_key) and await self._n_lock_contenders(self._lock_key(token)) > 0 ): if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} has contention, not leasing" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} has contention, not leasing" ) async with lock_held_ctx: state = await self.get_state(token) @@ -498,11 +505,11 @@ async def _try_modify_state( token, lock_id, cleanup_ctx=lock_held_ctx, **context ) ) is ( - current_lease_task := await self._get_local_lease(client_token) + current_lease_task := await self._get_local_lease(lock_key) ) and new_lease_task is not None: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} obtained lock {lock_id.decode()}." 
+ f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} obtained lock {lock_id.decode()}." ) elif current_lease_task is None: # Check if we still have the redis lock, then just try to send this one update and release it. @@ -510,7 +517,7 @@ async def _try_modify_state( if await self.redis.get(self._lock_key(token)) == lock_id: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..." ) async with lock_held_ctx: state = await self.get_state(token) @@ -519,7 +526,7 @@ async def _try_modify_state( return elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lock {lock_id.decode()} expired while waiting for lease task to exit..." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lock {lock_id.decode()} expired while waiting for lease task to exit..." ) # Have to retry getting the state, but now it's probably cached. yield None @@ -527,8 +534,8 @@ async def _try_modify_state( @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -538,14 +545,17 @@ async def modify_state( Yields: The state for the token. 
""" + token = self._coerce_token(token) while True: async with self._try_modify_state(token, **context) as state_instance: if state_instance is not None: - yield state_instance + yield cast(TOKEN_TYPE, state_instance) return @contextlib.asynccontextmanager - async def _get_state_cached(self, token: str) -> AsyncIterator[BaseState | None]: + async def _get_state_cached( + self, token: StateToken[TOKEN_TYPE] + ) -> AsyncIterator[TOKEN_TYPE | None]: """Get the cached state for a token, while holding the local lease lock. Args: @@ -553,40 +563,41 @@ async def _get_state_cached(self, token: str) -> AsyncIterator[BaseState | None] Yields: The cached state for the token, or None if not cached/uncachable. - - Raises: - RuntimeError: when the state_cls is not specified in the token. """ - client_token, state_path = _split_substate_key(token) + lock_key = token.lock_key # Opportunistically reuse existing lock. if ( - client_token in self._local_leases - and (state_lock := self._cached_states_locks.get(client_token)) is not None + lock_key in self._local_leases + and (state_lock := self._cached_states_locks.get(lock_key)) is not None ): async with state_lock: - if await self._get_local_lease(client_token) is not None: - if ( - cached_state := self._cached_states.get(client_token) - ) is not None: - # Make sure we have the substate cached (or fetch it from redis). - try: - substate = cached_state.get_substate(state_path.split(".")) - if len(substate.substates) != len( - type(substate).get_substates() - ): - # If the substate is missing substates, we need to refetch it. - raise ValueError # noqa: TRY301 - except ValueError: - await self.get_state(token, for_state_instance=cached_state) - yield cached_state + if await self._get_local_lease(lock_key) is not None: + if (cached_state := self._cached_states.get(lock_key)) is not None: + if isinstance(token, BaseStateToken): + # Make sure we have the substate cached (or fetch it from redis). 
+ state_path = token.cls.get_full_name() + try: + substate = cached_state.get_substate( + state_path.split(".") + ) + if len(substate.substates) != len( + type(substate).get_substates() + ): + # If the substate is missing substates, we need to refetch it. + raise ValueError # noqa: TRY301 + except ValueError: + await self.get_state( + token, for_state_instance=cached_state + ) + yield cast(TOKEN_TYPE, cached_state) return elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease task found, lock held, but no cached state" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease task found, lock held, but no cached state" ) elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} no active lease task found" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} no active lease task found" ) yield None @@ -608,7 +619,7 @@ def _notify_next_waiter(self, key: bytes): async def _create_lease_break_task( self, - token: str, + token: StateToken[TOKEN_TYPE], lock_id: bytes, cleanup_ctx: contextlib.AsyncExitStack, **context: Unpack[StateModificationContext], @@ -626,32 +637,32 @@ async def _create_lease_break_task( """ self._ensure_lock_task() - client_token, _ = _split_substate_key(token) + lock_key = token.lock_key async def do_flush() -> None: - if (state_lock := self._cached_states_locks.get(client_token)) is None: + if (state_lock := self._cached_states_locks.get(lock_key)) is None: # If we lost the lock, we can't write the state, something went wrong. console.warn( - f"State lock for {client_token} missing while finalizing lease." + f"State lock for {lock_key} missing while finalizing lease." ) return async with state_lock: # Write the state to redis while no one else can modify the cached copy. 
- state = self._cached_states.pop(client_token, None) + state = self._cached_states.pop(lock_key, None) try: if state: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} flushing state" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} flushing state" ) await self.set_state(token, state, lock_id=lock_id, **context) finally: - if (current_lease := self._local_leases.get(client_token)) is task: - self._local_leases.pop(client_token, None) + if (current_lease := self._local_leases.get(lock_key)) is task: + self._local_leases.pop(lock_key, None) # TODO: clean up the cached states locks periodically elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}." ) async def lease_breaker(): @@ -660,7 +671,7 @@ async def lease_breaker(): lease_break_time = self.oplock_hold_time_ms / 1000 if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s" ) try: await asyncio.sleep(lease_break_time) @@ -669,7 +680,7 @@ async def lease_breaker(): # We got cancelled so if someone is holding the lock, # extend the timeout so they get the full time to complete. 
if ( - state_lock := self._cached_states_locks[client_token] + state_lock := self._cached_states_locks[lock_key] ) is not None and state_lock.locked(): await self._try_extend_lock(self._lock_key(token)) try: @@ -688,10 +699,10 @@ async def lease_breaker(): if cancelled_error is not None: raise cancelled_error - if (state_lock := self._cached_states_locks.get(client_token)) is not None: + if (state_lock := self._cached_states_locks.get(lock_key)) is not None: # We have an existing lock, so lets see if we have an existing lease to cancel. async with state_lock: - if (existing_task := self._local_leases.get(client_token)) is not None: + if (existing_task := self._local_leases.get(lock_key)) is not None: # There's already a lease break task, so cancel it to clear it out. existing_task.cancel() if existing_task is not None: @@ -699,30 +710,28 @@ async def lease_breaker(): await existing_task # Now we might need to create a new lock. - if (state_lock := self._cached_states_locks.get(client_token)) is None: + if (state_lock := self._cached_states_locks.get(lock_key)) is None: async with self._state_manager_lock: - if (state_lock := self._cached_states_locks.get(client_token)) is None: - state_lock = self._cached_states_locks[client_token] = ( - asyncio.Lock() - ) + if (state_lock := self._cached_states_locks.get(lock_key)) is None: + state_lock = self._cached_states_locks[lock_key] = asyncio.Lock() async with state_lock: # Create the task now if one didn't sneak past us. if ( - client_token not in self._local_leases + lock_key not in self._local_leases and await self._n_lock_contenders(self._lock_key(token)) == 0 ): - self._local_leases[client_token] = task = asyncio.create_task( + self._local_leases[lock_key] = task = asyncio.create_task( lease_breaker(), - name=f"reflex_lease_breaker|{client_token}|{lock_id.decode()}", + name=f"reflex_lease_breaker|{lock_key}|{lock_id.decode()}", ) # Fetch the requested state into the cache. 
-            self._cached_states[client_token] = await self.get_state(token)
+            self._cached_states[lock_key] = await self.get_state(token)
             return task
         return None
 
     @staticmethod
-    def _lock_key(token: str) -> bytes:
+    def _lock_key(token: StateToken[Any]) -> bytes:
         """Get the redis key for a token's lock.
 
         Args:
@@ -731,9 +740,7 @@ def _lock_key(token: str) -> bytes:
         Returns:
             The redis lock key for the token.
         """
-        # All substates share the same lock domain, so ignore any substate path suffix.
-        client_token = _split_substate_key(token)[0]
-        return f"{client_token}_lock".encode()
+        return f"{token.lock_key}_lock".encode()
 
     async def _try_extend_lock(self, lock_key: bytes) -> bool | None:
         """Extends the current lock for another lock_expiration period.
@@ -1043,7 +1050,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
 
     @contextlib.asynccontextmanager
     async def _lock(
-        self, token: str, event_name: str | None = None
+        self, token: StateToken[Any], event_name: str | None = None
     ) -> AsyncIterator[bytes]:
         """Obtain a redis lock for a token.
 
diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py
new file mode 100644
index 00000000000..b376f76e6d9
--- /dev/null
+++ b/reflex/istate/manager/token.py
@@ -0,0 +1,244 @@
+"""Representation of a StateManager token."""
+
+from __future__ import annotations
+
+import dataclasses
+import pickle
+from typing import TYPE_CHECKING, BinaryIO, Generic, TypeVar
+
+from typing_extensions import Self
+
+from reflex.utils import console
+
+if TYPE_CHECKING:
+    from reflex.state import BaseState
+
+TOKEN_TYPE = TypeVar("TOKEN_TYPE")
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
+class StateToken(Generic[TOKEN_TYPE]):
+    """Token for referencing a state instance in the StateManager."""
+
+    # Identifier, usually the client_token, but could be a linked / shared token.
+    ident: str
+
+    # The class associated with the state instance.
+ cls: type[TOKEN_TYPE] + + def with_cls(self, cls: type[TOKEN_TYPE]) -> Self: + """Return a new token with the cls field updated to the provided class. + + Args: + cls: The class to update the cls field to. + + Returns: + A new StateToken instance with the updated cls field. + """ + return dataclasses.replace(self, cls=cls) + + @property + def cache_key(self) -> str: + """The key used for caching state instances in the StateManager. + + Returns: + A string key combining ident and class path. + """ + return str(self) + + @property + def lock_key(self) -> str: + """The key used for locking and session-level bookkeeping. + + Returns: + The token ident. + """ + return self.ident + + def __str__(self) -> str: + """The key used in the underlying StateManager store. + + Returns: + A string representation of the token, which is a combination of the ident and cls name. + """ + # urlencode the redis token to escape the slash delimiter. + clean_ident = self.ident.replace("/", "%2F") + clean_cls_name = f"{self.cls.__module__}.{self.cls.__name__}".replace( + "/", "%2F" + ) + return f"{clean_ident}/{clean_cls_name}" + + @classmethod + def serialize(cls, state: TOKEN_TYPE) -> bytes: + """Serialize the state for redis/disk storage. + + Args: + state: The state to serialize. + + Returns: + The serialized state. + """ + return pickle.dumps(state) + + @classmethod + def deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> TOKEN_TYPE: + """Deserialize the state from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized state instance. + """ + if data is not None and fp is not None: + msg = "Only one of `data` or `fp` may be provided, not both." 
+            raise ValueError(msg)
+        if data is not None:
+            return pickle.loads(data)
+        if fp is not None:
+            return pickle.load(fp)
+        msg = "At least one of `data` or `fp` must be provided."
+        raise ValueError(msg)
+
+    @classmethod
+    def get_and_reset_touched_state(cls, state: TOKEN_TYPE) -> bool:
+        """Get the touched state and reset the touched flag.
+
+        This is used to determine if a state has been modified since it was last serialized.
+
+        Args:
+            state: The state to check for modifications.
+
+        Returns:
+            The touched state of the state.
+        """
+        # Default implementation is always to write the state.
+        return True
+
+
+class BaseStateToken(StateToken["BaseState"]):
+    """A token for accessing reflex BaseState instances.
+
+    This token type implies subtree hierarchy population and other semantic checks.
+    """
+
+    @property
+    def cache_key(self) -> str:
+        """The key used for caching state instances in the StateManager.
+
+        BaseState tokens use just the ident because the entire state hierarchy
+        lives under a single root state instance per session.
+
+        Returns:
+            The token ident.
+        """
+        return self.ident
+
+    def with_cls(self, cls: type[BaseState]) -> Self:
+        """Return a new token with the cls field updated to the provided class.
+
+        Args:
+            cls: The class to update the cls field to.
+
+        Returns:
+            A new StateToken instance with the updated cls field.
+        """
+        return super().with_cls(cls)
+
+    def __str__(self) -> str:
+        """The key used in the underlying StateManager store.
+
+        Returns:
+            A string representation of the token, which is a combination of the ident and cls name.
+        """
+        # Legacy format: ident and fully-qualified state name joined by an underscore.
+        return f"{self.ident}_{self.cls.get_full_name()}"
+
+    @classmethod
+    def serialize(cls, state: BaseState) -> bytes:
+        """Serialize the BaseState for redis/disk storage.
+
+        Args:
+            state: The BaseState to serialize.
+
+        Returns:
+            The serialized state.
+ """ + return state._serialize() + + @classmethod + def deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> BaseState: + """Deserialize the BaseState from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized BaseState instance. + """ + from reflex.state import BaseState + + return BaseState._deserialize(data, fp) + + @classmethod + def get_and_reset_touched_state(cls, state: BaseState) -> bool: + """Get the touched state and reset the touched flag. + + This is used to determine if a state has been modified since it was last serialized. + + Args: + state: The BaseState to check for modifications. + + Returns: + The touched state of the BaseState. + """ + was_touched = state._get_was_touched() + state._was_touched = False # Reset the touched flag after serializing. + return was_touched + + @classmethod + def from_legacy_token( + cls, legacy_token: str, root_state: type[BaseState] | None + ) -> Self: + """Create a BaseStateToken from a legacy token string. + + The legacy token format is "{ident}_{module_path}.{class_name}". + + Args: + legacy_token: The legacy token string to convert. + root_state: The root state instance. + + Returns: + A BaseStateToken instance created from the legacy token. + + Raises: + ValueError: If the legacy token format is invalid or if the state class cannot be found + """ + from reflex.state import _split_substate_key + + if root_state is None: + msg = ( + "Root state must be provided to convert legacy token to BaseStateToken." 
+ ) + raise ValueError(msg) + + console.deprecate( + feature_name="Passing a string to modify_state", + reason="Use rx.BaseStateToken(token, state_cls) instead of the legacy string format", + deprecation_version="0.9.0", + removal_version="1.0", + ) + + client_token, state_path = _split_substate_key(legacy_token) + state_cls = root_state.get_class_substate(tuple(state_path.split("."))) # type: ignore[union-attr] + return cls(ident=client_token, cls=state_cls) diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index cbb62aec554..ce9aa3c8618 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -16,12 +16,13 @@ import wrapt from reflex_base.event import Event +from reflex_base.event.context import EventContext from reflex_base.utils.exceptions import ImmutableStateError from reflex_base.utils.serializers import can_serialize, serialize, serializer from reflex_base.vars.base import Var from typing_extensions import Self -from reflex.utils import prerequisites +from reflex.istate.manager.token import BaseStateToken if TYPE_CHECKING: from reflex.state import BaseState, StateUpdate @@ -73,15 +74,12 @@ def __init__( event: The event associated with the state modification context. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ - from reflex.state import _substate_key - super().__init__(state_instance) self._self_event = event - self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) - self._self_substate_token = _substate_key( - state_instance.router.session.client_token, - self._self_substate_path, + self._self_substate_token = BaseStateToken( + ident=EventContext.get().token, + cls=state_instance.__class__, ) self._self_actx = None self._self_mutable = False @@ -135,11 +133,13 @@ async def __aenter__(self) -> Self: msg = "The state is already mutable. Do not nest `async with self` blocks." 
raise ImmutableStateError(msg) + ctx = EventContext.get() + await self._self_actx_lock.acquire() try: self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=self._self_substate_token, background=True, event=self._self_event + self._self_actx = ctx.state_manager.modify_state_with_links( + token=self._self_substate_token, event=self._self_event ) mutable_state = await self._self_actx.__aenter__() self._self_mutable = True @@ -165,12 +165,22 @@ async def __aexit__(self, *exc_info: Any) -> None: return try: if self._self_mutable and self._self_actx is not None: - await self._self_actx.__aexit__(*exc_info) + root_state = self.__wrapped__._get_root_state() + delta = await root_state._get_resolved_delta() + root_state._clean() + # When the frontend vars are modified emit the delta to the frontend. + if delta: + ctx = EventContext.get() + await ctx.emit_delta(delta) finally: - self._self_actx = None - self._self_mutable = False - self._self_actx_lock_holder = None - self._self_actx_lock.release() + try: + if self._self_mutable and self._self_actx is not None: + await self._self_actx.__aexit__(*exc_info) + finally: + self._self_actx = None + self._self_mutable = False + self._self_actx_lock_holder = None + self._self_actx_lock.release() def __enter__(self): """Enter the regular context manager protocol. 
diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 284b8e2fbac..30c3b5c5fee 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -3,14 +3,16 @@ import asyncio import contextlib from collections.abc import AsyncIterator -from typing import Self, TypeVar +from typing import TypeVar from reflex_base.constants import ROUTER_DATA from reflex_base.event import Event, get_hydrate_event from reflex_base.utils import console from reflex_base.utils.exceptions import ReflexRuntimeError +from typing_extensions import Self -from reflex.state import BaseState, State, _override_base_method, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState, State, _override_base_method UPDATE_OTHER_CLIENT_TASKS: set[asyncio.Task] = set() LINKED_STATE = TypeVar("LINKED_STATE", bound="SharedStateBaseInternal") @@ -53,7 +55,7 @@ def _do_update_other_tokens( async def _update_client(token: str): async with app.modify_state( - _substate_key(token, state_type), + BaseStateToken(ident=token, cls=state_type), previous_dirty_vars=previous_dirty_vars, ): pass @@ -166,7 +168,6 @@ def _rehydrate(self): """ return [ Event( - token=self.router.session.client_token, name=get_hydrate_event(self._get_root_state()), ), State.set_is_hydrated(True), @@ -237,7 +238,10 @@ async def _unlink(self): # Patch in the original state, apply updates, then rehydrate. private_root_state = await get_state_manager().get_state( - _substate_key(self.router.session.client_token, type(self)) + BaseStateToken( + ident=self.router.session.client_token, + cls=type(self), + ) ) private_state = await private_root_state.get_state(type(self)) async with _patch_state( @@ -272,12 +276,14 @@ async def _internal_patch_linked_state( # Get the newly linked state and update pointers/delta for subsequent events. 
if token not in self._held_locks: linked_root_state = await self._exit_stack.enter_async_context( - get_state_manager().modify_state(_substate_key(token, type(self))) + get_state_manager().modify_state( + BaseStateToken(ident=token, cls=type(self)) + ) ) self._held_locks.setdefault(token, {}) else: linked_root_state = await get_state_manager().get_state( - _substate_key(token, type(self)) + BaseStateToken(ident=token, cls=type(self)) ) linked_state = await linked_root_state.get_state(type(self)) if not isinstance(linked_state, SharedState): diff --git a/reflex/istate/wrappers.py b/reflex/istate/wrappers.py index 5896b378cb8..09e0a4fa3a7 100644 --- a/reflex/istate/wrappers.py +++ b/reflex/istate/wrappers.py @@ -3,8 +3,9 @@ from typing import Any from reflex.istate.manager import get_state_manager +from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import ReadOnlyStateProxy -from reflex.state import _split_substate_key, _substate_key +from reflex.state import State, _split_substate_key async def get_state(token: str, state_cls: Any | None = None) -> ReadOnlyStateProxy: @@ -19,9 +20,9 @@ async def get_state(token: str, state_cls: Any | None = None) -> ReadOnlyStatePr """ mng = get_state_manager() if state_cls is not None: - root_state = await mng.get_state(_substate_key(token, state_cls)) + root_state = await mng.get_state(BaseStateToken(ident=token, cls=state_cls)) else: - root_state = await mng.get_state(token) + root_state = await mng.get_state(BaseStateToken(ident=token, cls=State)) _, state_path = _split_substate_key(token) state_cls = root_state.get_class_substate(tuple(state_path.split("."))) instance = await root_state.get_state(state_cls) diff --git a/reflex/state.py b/reflex/state.py index 98004fa297f..463a1057e50 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,19 +7,14 @@ import contextlib import copy import dataclasses -import datetime import functools import inspect import pickle import re import sys import time 
-import uuid -import warnings -from collections.abc import AsyncIterator, Callable, Iterator, Sequence -from enum import Enum +from collections.abc import Callable, Iterator, Mapping, Sequence from hashlib import md5 -from importlib.util import find_spec from types import FunctionType from typing import ( TYPE_CHECKING, @@ -28,7 +23,6 @@ ClassVar, ParamSpec, TypeVar, - cast, get_type_hints, ) @@ -42,7 +36,6 @@ EventHandler, EventSpec, call_script, - fix_events, ) from reflex_base.utils.exceptions import ( ComputedVarShadowsBaseVarsError, @@ -59,7 +52,8 @@ UnretrievableVarValueError, ) from reflex_base.utils.exceptions import ImmutableStateError as ImmutableStateError -from reflex_base.utils.types import _isinstance, is_union, value_inside_optional +from reflex_base.utils.serializers import serializer +from reflex_base.utils.types import _isinstance from reflex_base.vars import Field, VarData, field from reflex_base.vars.base import ( ComputedVar, @@ -78,9 +72,8 @@ from reflex.istate import HANDLED_PICKLE_ERRORS, debug_failed_pickles from reflex.istate.data import RouterData from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy -from reflex.istate.proxy import MutableProxy, StateProxy, is_mutable_type +from reflex.istate.proxy import MutableProxy, is_mutable_type from reflex.istate.storage import ClientStorageBase -from reflex.model import Model from reflex.utils import console, format, prerequisites, types from reflex.utils.exec import is_testing_env @@ -88,7 +81,8 @@ from reflex_base.components.component import Component -Delta = dict[str, Any] +Delta = dict[str, dict[str, Any]] +DeltaMapping = Mapping[str, Mapping[str, Any]] var = computed_var @@ -179,8 +173,6 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]: class EventHandlerSetVar(EventHandler): """A special event handler to wrap setvar functionality.""" - state_cls: type[BaseState] = dataclasses.field(init=False) - def __init__(self, state_cls: type[BaseState]): 
"""Initialize the EventHandlerSetVar. @@ -189,9 +181,8 @@ def __init__(self, state_cls: type[BaseState]): """ super().__init__( fn=type(self).setvar, - state_full_name=state_cls.get_full_name(), + state=state_cls, ) - object.__setattr__(self, "state_cls", state_cls) def __hash__(self): """Get the hash of the event handler. @@ -203,7 +194,7 @@ def __hash__(self): tuple(self.event_actions.items()), self.fn, self.state_full_name, - self.state_cls, + self.state, )) def setvar(self, var_name: str, value: Any): @@ -235,11 +226,11 @@ def __call__(self, *args: Any) -> EventSpec: from reflex_base.utils.exceptions import EventHandlerValueError config = get_config() - if config.state_auto_setters is None: + if config.state_auto_setters is None and self.state is not None: console.deprecate( feature_name="state_auto_setters defaulting to True", reason="The default value will be changed to False in a future release. Set state_auto_setters explicitly or define setters explicitly. " - f"Used {self.state_cls.__name__}.setvar without defining it.", + f"Used {self.state.__name__}.setvar without defining it.", deprecation_version="0.8.9", removal_version="0.9.0", dedupe=True, @@ -250,11 +241,11 @@ def __call__(self, *args: Any) -> EventSpec: msg = f"Var name must be passed as a string, got {args[0]!r}" raise EventHandlerValueError(msg) - handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) + handler = getattr(self.state, constants.SETTER_PREFIX + args[0], None) # Check that the requested Var setter exists on the State at compile time. 
if handler is None: - msg = f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" + msg = f"Variable `{args[0]}` cannot be set on `{self.state_full_name}`" raise AttributeError(msg) if inspect.iscoroutinefunction(handler.fn): @@ -325,15 +316,6 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU return fn -_deserializers = { - int: int, - float: float, - datetime.datetime: datetime.datetime.fromisoformat, - datetime.date: datetime.date.fromisoformat, - datetime.time: datetime.time.fromisoformat, - uuid.UUID: uuid.UUID, -} - all_base_state_classes: dict[str, None] = {} CLASS_VAR_NAMES = frozenset({ @@ -344,7 +326,6 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU "backend_vars", "inherited_backend_vars", "event_handlers", - "class_subclasses", "_var_dependencies", "_always_dirty_computed_vars", "_always_dirty_substates", @@ -376,9 +357,6 @@ class BaseState(EvenMoreBasicBaseState): # The event handlers. event_handlers: ClassVar[builtins.dict[str, EventHandler]] = {} - # A set of subclassses of this class. - class_subclasses: ClassVar[set[type[BaseState]]] = set() - # Mapping of var name to set of (state_full_name, var_name) that depend on it. _var_dependencies: ClassVar[builtins.dict[str, set[tuple[str, str]]]] = {} @@ -518,6 +496,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ + from reflex_base.registry import RegistrationContext from reflex_base.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) @@ -538,9 +517,6 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): # Computed vars should not shadow builtin state props. cls._check_overridden_basevars() - # Reset subclass tracking for this class. - cls.class_subclasses = set() - # Reset dirty substate tracking for this class. 
cls._always_dirty_substates = set() cls._potentially_dirty_states = set() @@ -552,15 +528,13 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars # Check if another substate class with the same name has already been defined. - if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}: + if cls.get_name() in {c.get_name() for c in parent_state.get_substates()}: # This should not happen, since we have added module prefix to state names in #3214 msg = ( f"The substate class '{cls.get_name()}' has been defined multiple times. " "Shadowing substate classes is not allowed." ) raise StateValueError(msg) - # Track this new subclass in the parent state's subclasses set. - parent_state.class_subclasses.add(cls) # Get computed vars. computed_vars = cls._get_computed_vars() @@ -644,6 +618,8 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.event_handlers[name] = handler setattr(cls, name, handler) + RegistrationContext.register_base_state(cls) + # Initialize per-class var dependency tracking. cls._var_dependencies = {} cls._init_var_dependency_dicts() @@ -984,7 +960,9 @@ def get_substates(cls) -> set[type[BaseState]]: Returns: The substates of the state. """ - return cls.class_subclasses + from reflex_base.registry import RegistrationContext + + return RegistrationContext.get().get_substates(cls) @classmethod @functools.lru_cache @@ -1138,7 +1116,7 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None): cls.vars.update({name: var}) # let substates know about the new variable - for substate_class in cls.class_subclasses: + for substate_class in cls.get_substates(): substate_class.vars.setdefault(name, var) # Reinitialize dependency tracking dicts. @@ -1167,12 +1145,17 @@ def _create_event_handler( Returns: The event handler. 
""" + from reflex_base.registry import RegistrationContext + # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) - return event_handler_cls( - fn=fn, state_full_name=cls.get_full_name(), event_actions=event_actions - ) + handler = event_handler_cls(fn=fn, state=cls, event_actions=event_actions) + if cls.get_full_name() in all_base_state_classes: + # Register handlers created after the class was registered. + reg_ctx = RegistrationContext.get() + reg_ctx.register_event_handler(handler, states=(cls,)) + return handler @classmethod def _create_setvar(cls): @@ -1276,7 +1259,7 @@ def _update_substate_inherited_vars(cls, vars_to_add: dict[str, Var]): Args: vars_to_add: names to Var instances to add to substates """ - for substate_class in cls.class_subclasses: + for substate_class in cls.get_substates(): for name, var in vars_to_add.items(): if types.is_backend_base_variable(name, cls): substate_class.backend_vars.setdefault(name, var) @@ -1619,6 +1602,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: """ from reflex.istate.manager import get_state_manager from reflex.istate.manager.redis import StateManagerRedis + from reflex.istate.manager.token import BaseStateToken # Then get the target state and all its substates. state_manager = get_state_manager() @@ -1629,7 +1613,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: ) raise RuntimeError(msg) state_in_redis = await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=BaseStateToken(ident=self.router.session.client_token, cls=state_cls), top_level=False, for_state_instance=self, ) @@ -1720,289 +1704,6 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: ) return getattr(other_state, var_data.field_name) - def _get_event_handler(self, event: Event | str) -> tuple[BaseState, EventHandler]: - """Get the event handler for the given event. 
- - Args: - event: The event to get the handler for, or a dotted handler name string. - - - Returns: - The event handler. - - Raises: - ValueError: If the event handler or substate is not found. - """ - # Get the event handler. - name = event.name if isinstance(event, Event) else event - path = name.split(".") - path, name = path[:-1], path[-1] - substate = self.get_substate(path) - if not substate: - msg = "The value of state cannot be None when processing an event." - raise ValueError(msg) - handler = substate.event_handlers[name] - - return substate, handler - - async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: - """Obtain event info and process event. - - Args: - event: The event to process. - - Yields: - The state update after processing the event. - """ - # Get the event handler. - substate, handler = self._get_event_handler(event) - - # For background tasks, proxy the state. - if handler.is_background: - substate = StateProxy(substate, event) - - # Run the event generator and yield state updates. - async for update in self._process_event( - handler=handler, - state=substate, - payload=event.payload, - ): - yield update - - def _check_valid(self, handler: EventHandler, events: Any) -> Any: - """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. - - Args: - handler: EventHandler. - events: The events to be checked. - - Returns: - The events as they are if valid. - - Raises: - TypeError: If any of the events are not valid. 
- """ - - def _is_valid_type(events: Any) -> bool: - return isinstance(events, (Event, EventHandler, EventSpec)) - - if events is None or _is_valid_type(events): - return events - - if not (isinstance(events, Sequence) and not isinstance(events, (str, bytes))): - events = [events] - - try: - if all(_is_valid_type(e) for e in events): - return events - except TypeError: - pass - - coroutines = [e for e in events if inspect.iscoroutine(e)] - - for coroutine in coroutines: - coroutine_name = coroutine.__qualname__ - warnings.filterwarnings( - "ignore", message=f"coroutine '{coroutine_name}' was never awaited" - ) - - msg = ( - f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references)." - f" Returned events of types {', '.join(map(str, map(type, events)))!s}." - ) - raise TypeError(msg) - - async def _as_state_update( - self, - handler: EventHandler, - events: EventSpec | list[EventSpec] | None, - final: bool, - ) -> StateUpdate: - """Convert the events to a StateUpdate. - - Fixes the events and checks for validity before converting. - - Args: - handler: The handler where the events originated from. - events: The events to queue with the update. - final: Whether the handler is done processing. - - Returns: - The valid StateUpdate containing the events and final flag. - """ - # get the delta from the root of the state tree - state = self._get_root_state() - - token = self.router.session.client_token - - # Convert valid EventHandler and EventSpec into Event - fixed_events = fix_events(self._check_valid(handler, events), token) - - try: - # Get the delta after processing the event. 
- delta = await state._get_resolved_delta() - state._clean() - - return StateUpdate( - delta=delta, - events=fixed_events, - final=final if not handler.is_background else None, - ) - except Exception as ex: - state._clean() - - event_specs = ( - prerequisites.get_and_validate_app().app.backend_exception_handler(ex) - ) - - if event_specs is None: - return StateUpdate() - - event_specs_correct_type = cast( - list[EventSpec | EventHandler] | None, - [event_specs] if isinstance(event_specs, EventSpec) else event_specs, - ) - fixed_events = fix_events( - event_specs_correct_type, - token, - router_data=state.router_data, - ) - return StateUpdate( - events=fixed_events, - final=True, - ) - - async def _process_event( - self, - handler: EventHandler, - state: BaseState | StateProxy, - payload: builtins.dict, - ) -> AsyncIterator[StateUpdate]: - """Process event. - - Args: - handler: EventHandler to process. - state: State to process the handler. - payload: The event payload. - - Yields: - StateUpdate object - - Raises: - ValueError: If a string value is received for an int or float type and cannot be converted. - """ - from reflex.utils import telemetry - - # Get the function to process the event. 
- fn = functools.partial(handler.fn, state) - - try: - type_hints = types.get_type_hints(handler.fn) - except Exception: - type_hints = {} - - for arg, value in list(payload.items()): - hinted_args = type_hints.get(arg, Any) - if hinted_args is Any: - continue - if is_union(hinted_args): - if value is None: - continue - hinted_args = value_inside_optional(hinted_args) - if ( - isinstance(value, dict) - and isinstance(hinted_args, type) - and not types.is_generic_alias(hinted_args) # py3.10 - ): - if issubclass(hinted_args, Model): - # Remove non-fields from the payload - payload[arg] = hinted_args(**{ - key: value - for key, value in value.items() - if key in hinted_args.__fields__ - }) - elif dataclasses.is_dataclass(hinted_args): - payload[arg] = hinted_args(**value) - elif find_spec("pydantic"): - from pydantic import BaseModel - - if issubclass(hinted_args, BaseModel): - payload[arg] = hinted_args.model_validate(value) - elif isinstance(value, list) and (hinted_args is set or hinted_args is set): - payload[arg] = set(value) - elif isinstance(value, list) and ( - hinted_args is tuple or hinted_args is tuple - ): - payload[arg] = tuple(value) - elif isinstance(hinted_args, type) and issubclass(hinted_args, Enum): - try: - payload[arg] = hinted_args(value) - except ValueError: - msg = f"Received an invalid enum value ({value}) for {arg} of type {hinted_args}" - raise ValueError(msg) from None - elif ( - isinstance(value, str) - and (deserializer := _deserializers.get(hinted_args)) is not None - ): - try: - payload[arg] = deserializer(value) - except ValueError: - msg = f"Received a string value ({value}) for {arg} but expected a {hinted_args}" - raise ValueError(msg) from None - else: - console.warn( - f"Received a string value ({value}) for {arg} but expected a {hinted_args}. A simple conversion was successful." - ) - - # Wrap the function in a try/except block. - try: - # Handle async functions. 
- if inspect.iscoroutinefunction(fn.func): - events = await fn(**payload) - - # Handle regular functions. - else: - events = fn(**payload) - # Handle async generators. - if inspect.isasyncgen(events): - async for event in events: - yield await state._as_state_update(handler, event, final=False) - yield await state._as_state_update(handler, events=None, final=True) - - # Handle regular generators. - elif inspect.isgenerator(events): - try: - while True: - yield await state._as_state_update( - handler, next(events), final=False - ) - except StopIteration as si: - # the "return" value of the generator is not available - # in the loop, we must catch StopIteration to access it - if si.value is not None: - yield await state._as_state_update( - handler, si.value, final=False - ) - yield await state._as_state_update(handler, events=None, final=True) - - # Handle regular event chains. - else: - yield await state._as_state_update(handler, events, final=True) - - # If an error occurs, throw a window alert. - except Exception as ex: - telemetry.send_error(ex, context="backend") - - event_specs = ( - prerequisites.get_and_validate_app().app.backend_exception_handler(ex) - ) - - yield await state._as_state_update( - handler, - event_specs, - final=True, - ) - def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" # Append expired computed vars to dirty_vars to trigger recalculation @@ -2187,7 +1888,9 @@ def get_value(self, key: str) -> Any: return key.__wrapped__ if isinstance(key, str): - return getattr(self, key) + if isinstance(val := getattr(self, key), MutableProxy): + return val.__wrapped__ + return val msg = f"Invalid key type: {type(key)}. Expected str." 
raise TypeError(msg) @@ -2508,6 +2211,25 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: return await internal_patch_linked_state(linked_token) return state_instance + @event + async def hydrate(self) -> None: + """Send the full state to the frontend to synchronize it with the backend.""" + from reflex_base.event.context import EventContext + + # Clear client storage, to respect clearing cookies + self._reset_client_storage() + + # Mark state as not hydrated (until on_loads are complete) + self.is_hydrated = False + + # Get the initial state if needed. + ctx = EventContext.get() + if ctx.emit_delta_impl is not None: + await ctx.emit_delta(delta=await _resolve_delta(self.dict())) + + # since a full dict was captured, clean any dirtiness + self._clean() + @event def set_is_hydrated(self, value: bool) -> None: """Set the hydrated state. @@ -2676,9 +2398,8 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No return None # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ - *fix_events( - cast(list[EventSpec | EventHandler], load_events), - self.router.session.client_token, + *Event.from_event_type( + load_events, router_data=self.router_data, ), State.set_is_hydrated(True), @@ -2804,21 +2525,38 @@ class StateUpdate: """A state update sent to the frontend.""" # The state delta. - delta: Delta = dataclasses.field(default_factory=dict) + delta: DeltaMapping = dataclasses.field(default_factory=dict) # Events to be added to the event queue. events: list[Event] = dataclasses.field(default_factory=list) - # Whether this is the final state update for the event. - final: bool | None = True + # Deprecated: previously indicated whether the event processing is complete. + final: bool | None = dataclasses.field(default=None, repr=False) - def json(self) -> str: - """Convert the state update to a JSON string. 
+ def __post_init__(self): + """Warn if the deprecated `final` attribute is supplied.""" + if self.final is not None: + console.deprecate( + feature_name="StateUpdate.final", + reason="The final attribute is no longer used.", + deprecation_version="0.9.0", + removal_version="1.0", + ) - Returns: - The state update as a JSON string. - """ - return format.json_dumps(self) + +@serializer(to=dict) +def serialize_state_update(update: StateUpdate) -> dict: + """Serialize a StateUpdate to a dictionary. + + Args: + update: The StateUpdate to serialize. + + Returns: + The serialized StateUpdate. + """ + return { + k.name: v for k in dataclasses.fields(update) if (v := getattr(update, k.name)) + } def code_uses_state_contexts(javascript_code: str) -> bool: @@ -2844,6 +2582,8 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ + from reflex_base.registry import RegistrationContext + # Reset the _app_ref of OnLoadInternalState to avoid stale references. if state is OnLoadInternalState: state._app_ref = None @@ -2855,11 +2595,13 @@ def reload_state_module( and module is not None ): state._potentially_dirty_states.remove(pd_state) - for subclass in tuple(state.class_subclasses): + reg_ctx = RegistrationContext.get() + substates = reg_ctx.get_substates(state) + for subclass in tuple(substates): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: all_base_state_classes.pop(subclass.get_full_name(), None) - state.class_subclasses.remove(subclass) + substates.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) state._var_dependencies = {} state._init_var_dependency_dicts() diff --git a/reflex/testing.py b/reflex/testing.py index b80dcd842a0..c144b22d7c6 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import contextvars import dataclasses import functools import inspect @@ -19,16 +20,18 @@ import threading import 
time import types -from collections.abc import AsyncIterator, Callable, Coroutine, Sequence +from collections.abc import Callable, Coroutine, Sequence +from copy import deepcopy from http.server import SimpleHTTPRequestHandler from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn from reflex_base.components.component import CUSTOM_COMPONENTS, CustomComponent from reflex_base.config import get_config from reflex_base.environment import environment +from reflex_base.registry import RegistrationContext from reflex_base.utils.types import ASGIApp from typing_extensions import Self @@ -39,11 +42,8 @@ import reflex.utils.prerequisites import reflex.utils.processes from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.istate.manager import StateManager -from reflex.istate.manager.disk import StateManagerDisk -from reflex.istate.manager.memory import StateManagerMemory -from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, _split_substate_key, reload_state_module +from reflex.istate.shared import SharedState as SharedState # To register it. 
+from reflex.state import reload_state_module from reflex.utils import console, js_runtimes from reflex.utils.export import export from reflex.utils.token_manager import TokenManager @@ -118,8 +118,9 @@ class AppHarness: frontend_output_thread: threading.Thread | None = None backend_thread: threading.Thread | None = None backend: uvicorn.Server | None = None - state_manager: StateManager | None = None _frontends: list[WebDriver] = dataclasses.field(default_factory=list) + _registry_token: contextvars.Token[RegistrationContext] | None = None + _base_registration_context: ClassVar[RegistrationContext] | None = None @classmethod def create( @@ -238,7 +239,6 @@ def _get_source_from_app_source(self, app_source: Any) -> str: def _initialize_app(self): # disable telemetry reporting for tests - os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" # Reset global memo registries so previous AppHarness apps do not # leak compiled component definitions into the next test app. @@ -269,6 +269,14 @@ def _initialize_app(self): with chdir(self.app_path): reflex.utils.prerequisites.initialize_frontend_dependencies() with chdir(self.app_path): + # Use a new registration context for a new app. + if AppHarness._base_registration_context is None: + # Save the initial RegistrationContext for the app if we haven't already + AppHarness._base_registration_context = ( + RegistrationContext.ensure_context() + ) + new_registration_context = deepcopy(AppHarness._base_registration_context) + self._registry_token = RegistrationContext.set(new_registration_context) # ensure config and app are reloaded when testing different app config = get_config(reload=True) # Ensure the AppHarness test does not skip State assignment due to running via pytest @@ -285,19 +293,6 @@ def _initialize_app(self): ) ) self.app_asgi = self.app_instance() - if self.app_instance and self.app_instance._state_manager is not None: - if self.app_instance._state is None: - msg = "State is not set." 
- raise RuntimeError(msg) - if isinstance(self.app_instance._state_manager, StateManagerRedis): - # Create our own redis connection for testing. - self.state_manager = StateManagerRedis.create(self.app_instance._state) - elif isinstance(self.app_instance._state_manager, StateManagerDisk): - self.state_manager = StateManagerDisk.create(self.app_instance._state) - if self.state_manager is None: - self.state_manager = ( - self.app_instance._state_manager if self.app_instance else None - ) def _reload_state_module(self): """Reload the rx.State module to avoid conflict when reloading.""" @@ -349,55 +344,21 @@ def _start_backend(self, port: int = 0): ) ) self.backend.shutdown = self._get_backend_shutdown_handler() + + def _run_backend(context: contextvars.Context) -> None: + if self.backend is not None: + context.run(self.backend.run) + with chdir(self.app_path): print( # noqa: T201 "Creating backend in a new thread..." ) # for pytest diagnosis - self.backend_thread = threading.Thread(target=self.backend.run) + self.backend_thread = threading.Thread( + target=_run_backend, args=(contextvars.copy_context(),) + ) self.backend_thread.start() print("Backend started.") # for pytest diagnosis #noqa: T201 - async def _reset_backend_state_manager(self): - """Reset the StateManagerRedis event loop affinity. - - This is necessary when the backend is restarted and the state manager is a - StateManagerRedis instance. 
- - Raises: - RuntimeError: when the state manager cannot be reset - """ - if ( - self.app_instance is not None - and self.app_instance._state_manager is not None - ): - with contextlib.suppress(RuntimeError): - await self.app_instance._state_manager.close() - if ( - self.app_instance is not None - and isinstance( - self.app_instance._state_manager, - StateManagerRedis, - ) - and self.app_instance._state is not None - ): - self.app_instance._state_manager = StateManagerRedis.create( - state=self.app_instance._state, - ) - if not isinstance(self.app_instance.state_manager, StateManagerRedis): - msg = "Failed to reset state manager." - raise RuntimeError(msg) - - # Also reset the TokenManager to avoid loop affinity issues - if ( - hasattr(self.app_instance, "event_namespace") - and self.app_instance.event_namespace is not None - and hasattr(self.app_instance.event_namespace, "_token_manager") - ): - # Import here to avoid circular imports - from reflex.utils.token_manager import TokenManager - - self.app_instance.event_namespace._token_manager = TokenManager.create() - def _start_frontend(self): # Set up the frontend. with chdir(self.app_path): @@ -504,6 +465,8 @@ def stop(self) -> None: driver.quit() self._reload_state_module() + if self._registry_token is not None: + RegistrationContext.reset(self._registry_token) if self.backend is not None: self.backend.should_exit = True @@ -716,98 +679,6 @@ def frontend( self._frontends.append(driver) return driver - async def get_state(self, token: str) -> BaseState: - """Get the state associated with the given token. - - Args: - token: The state token to look up. - - Returns: - The state instance associated with the given token - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." 
- raise RuntimeError(msg) - if self.app_instance is not None and isinstance( - self.app_instance.state_manager, StateManagerDisk - ): - # Song and dance to convince the instance's state manager to flush - # (we can't directly await the _other_ loop's Future) - await self.app_instance.state_manager._flush_write_queue() - if isinstance(self.state_manager, StateManagerDisk): - # Force reload the latest state from disk. - client_token, _ = _split_substate_key(token) - self.state_manager.states.pop(client_token, None) - try: - return await self.state_manager.get_state(token) - finally: - await self.state_manager.close() - - async def set_state(self, token: str, **kwargs) -> None: - """Set the state associated with the given token. - - Args: - token: The state token to set. - kwargs: Attributes to set on the state. - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." - raise RuntimeError(msg) - state = await self.get_state(token) - for key, value in kwargs.items(): - setattr(state, key, value) - try: - await self.state_manager.set_state(token, state) - finally: - if self.app_instance is not None and isinstance( - self.app_instance.state_manager, StateManagerDisk - ): - # Clear the token from the backend's cache so it will be reloaded. - client_token, _ = _split_substate_key(token) - self.app_instance.state_manager.states.pop(client_token, None) - await self.state_manager.close() - - @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: - """Modify the state associated with the given token and send update to frontend. - - Args: - token: The state token to modify - - Yields: - The state instance associated with the given token - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." 
- raise RuntimeError(msg) - if self.app_instance is None: - msg = "App is not running." - raise RuntimeError(msg) - app_state_manager = self.app_instance.state_manager - if isinstance(self.state_manager, (StateManagerRedis, StateManagerDisk)): - # Temporarily replace the app's state manager with our own, since - # the redis/disk connection is on the backend_thread event loop - self.app_instance._state_manager = self.state_manager - try: - async with self.app_instance.modify_state(token) as state: - yield state - finally: - if isinstance(app_state_manager, StateManagerDisk): - # Clear the token from the cache so it will be reloaded. - client_token, _ = _split_substate_key(token) - app_state_manager.states.pop(client_token, None) - await self.state_manager.close() - self.app_instance._state_manager = app_state_manager - def token_manager(self) -> TokenManager: """Get the token manager for the app instance. @@ -878,35 +749,6 @@ def poll_for_value( raise TimeoutError(msg) return element.get_attribute("value") - def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: - """Poll app state_manager for any connected clients. - - Args: - timeout: how long to wait for client states - - Returns: - active state instances when the polling loop exited - - Raises: - RuntimeError: when the app hasn't started running - TimeoutError: when the timeout expires before any states are seen - ValueError: when the state_manager is not a memory state manager - """ - if self.app_instance is None: - msg = "App is not running." - raise RuntimeError(msg) - state_manager = self.app_instance.state_manager - if not isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - msg = "Only works with memory or disk state manager" - raise ValueError(msg) - if not self._poll_for( - target=lambda: state_manager.states, - timeout=timeout, - ): - msg = "No states were observed while polling." 
- raise TimeoutError(msg) - return state_manager.states - @staticmethod def poll_for_or_raise_timeout( target: Callable[[], T], @@ -1118,10 +960,17 @@ def _start_backend(self): ), ) self.backend.shutdown = self._get_backend_shutdown_handler() + + def _run_backend(context: contextvars.Context) -> None: + if self.backend is not None: + context.run(self.backend.run) + print( # noqa: T201 "Creating backend in a new thread..." ) - self.backend_thread = threading.Thread(target=self.backend.run) + self.backend_thread = threading.Thread( + target=_run_backend, args=(contextvars.copy_context(),) + ) self.backend_thread.start() print("Backend started.") # for pytest diagnosis #noqa: T201 diff --git a/reflex/utils/tasks.py b/reflex/utils/tasks.py index 8a1a8cebc04..99dd15ee4e3 100644 --- a/reflex/utils/tasks.py +++ b/reflex/utils/tasks.py @@ -3,6 +3,7 @@ import asyncio import time from collections.abc import Callable, Coroutine +from contextvars import Context from typing import Any from reflex_base.utils import console @@ -64,6 +65,7 @@ def ensure_task( exception_delay: float = 1.0, exception_limit: int = 5, exception_limit_window: float = 60.0, + task_context: Context | None = None, **kwargs: Any, ) -> asyncio.Task: """Ensure that a task is running for the given coroutine function. @@ -78,6 +80,7 @@ def ensure_task( exception_delay: The delay between retries when an exception is suppressed. exception_limit: The maximum number of suppressed exceptions within the limit window before raising. exception_limit_window: The time window in seconds for counting suppressed exceptions. + task_context: The context to use for the task. *args: The arguments to pass to the coroutine function. **kwargs: The keyword arguments to pass to the coroutine function. @@ -93,17 +96,20 @@ def ensure_task( task = getattr(owner, task_attribute, None) if task is None or task.done(): asyncio.get_running_loop() # Ensure we're in an event loop. 
- task = asyncio.create_task( - _run_forever( - coro_function, - *args, - suppress_exceptions=suppress_exceptions, - exception_delay=exception_delay, - exception_limit=exception_limit, - exception_limit_window=exception_limit_window, - **kwargs, - ), - name=f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}", + rf_coro = _run_forever( + coro_function, + *args, + suppress_exceptions=suppress_exceptions, + exception_delay=exception_delay, + exception_limit=exception_limit, + exception_limit_window=exception_limit_window, + **kwargs, ) + task_name = f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}" + if task_context is not None: + # Run the task in the given context (not needed after Python 3.11+ which supports passing context to create_task directly). + task = task_context.run(asyncio.create_task, rf_coro, name=task_name) + else: + task = asyncio.create_task(rf_coro, name=task_name) setattr(owner, task_attribute, task) return task diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index 1fc982ea32e..2259b26dfc8 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, ClassVar from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, StateUpdate +from reflex.state import StateUpdate from reflex.utils import console, prerequisites from reflex.utils.tasks import ensure_task @@ -248,9 +248,7 @@ async def _handle_socket_record_del( async def _subscribe_socket_record_updates(self) -> None: """Subscribe to Redis keyspace notifications for socket record updates.""" - await StateManagerRedis( - state=BaseState, redis=self.redis - )._enable_keyspace_notifications() + await StateManagerRedis(redis=self.redis)._enable_keyspace_notifications() redis_db = self.redis.get_connection_kwargs().get("db", 0) async with self.redis.pubsub() as pubsub: diff 
--git a/tests/benchmarks/test_event_processing.py b/tests/benchmarks/test_event_processing.py index 0acf26488fe..15acf8094d4 100644 --- a/tests/benchmarks/test_event_processing.py +++ b/tests/benchmarks/test_event_processing.py @@ -1,112 +1,102 @@ """Benchmark for the event processing pipeline. -Measures the time from calling the ``process`` function (the core of -``on_event``) to collecting all emitted ``StateUpdate`` deltas via -``contextlib.aclosing`` over the async generator. +Measures the time from enqueuing events via ``BaseStateEventProcessor`` +to collecting all emitted ``StateUpdate`` deltas, with mock emit +callbacks that record the deltas. """ import asyncio -import contextlib +import traceback +from collections.abc import Mapping +from typing import Any from unittest import mock import pytest import pytest_asyncio from pytest_codspeed import BenchmarkFixture +from reflex_base.event import Event +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor from reflex_base.utils.format import format_event_handler -from reflex.app import App, process -from reflex.event import Event from reflex.istate.manager.memory import StateManagerMemory -from reflex.state import State from .fixtures import BenchmarkState -@pytest.fixture -def app_module_mock(monkeypatch) -> mock.Mock: - """Mock the app module so state machinery can resolve the app. +@pytest_asyncio.fixture +async def event_processing_harness(): + """Set up the full event processing pipeline for benchmarking. - Args: - monkeypatch: pytest monkeypatch fixture. + Creates a ``BaseStateEventProcessor`` wired to a real + ``StateManagerMemory`` with mock emit callbacks. Events are + enqueued directly and deltas are collected via the emit callback. - Returns: - The mock for the main app module. + Yields: + An async callable that enqueues the given number of events + and waits for all expected deltas. 
""" - from reflex.utils import prerequisites + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]] = [] - app_module_mock = mock.Mock() - get_app_mock = mock.Mock(return_value=app_module_mock) - monkeypatch.setattr(prerequisites, "get_app", get_app_mock) - return app_module_mock + async def emit_delta_impl( # noqa: RUF029 + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + emitted_deltas.append((token, delta)) + async def emit_event_impl(token: str, *events: Event) -> None: + pass -@pytest_asyncio.fixture -async def event_processing_harness(app_module_mock: mock.Mock): - """Set up the full event processing pipeline for benchmarking. + def handle_backend_exception(ex: Exception) -> None: + formatted_exc = "\n".join(traceback.format_exception(ex)) + pytest.fail(f"Event processor raised an unexpected exception:\n{formatted_exc}") - Creates an App wired to a real StateManagerMemory. The ``process`` - function is called directly (bypassing Socket.IO) and StateUpdates - are collected and counted to verify correctness. + processor = BaseStateEventProcessor( + backend_exception_handler=handle_backend_exception, + graceful_shutdown_timeout=5, + ) + # Mock _rehydrate so the processor doesn't try to push full state + # to a non-existent frontend on the first event. + with mock.patch.object(processor, "_rehydrate", new=mock.AsyncMock()): + state_manager = StateManagerMemory() + root_context = EventContext( + token="", + state_manager=state_manager, + enqueue_impl=processor.enqueue_many, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + processor._root_context = root_context + + token = "benchmark-token" + handler_name = format_event_handler(BenchmarkState.event_handlers["increment"]) + event = Event( + name=handler_name, + router_data={ + "query": {}, + "path": "/", + }, + payload={}, + ) - Args: - app_module_mock: The mocked app module. 
+ async def run_events(num_events: int, num_expected_deltas: int) -> None: + """Enqueue events and wait for all deltas to be emitted. - Yields: - An async callable that processes the given events and asserts - the expected number of deltas were produced. - """ - app = app_module_mock.app = App() - state_manager = StateManagerMemory(state=State) - app._state_manager = state_manager - # Disable event namespace so process() doesn't try to emit "reload" - # via Socket.IO for brand-new states. - app._event_namespace = None - - token = "benchmark-token" - sid = "benchmark-sid" - headers: dict = {} - client_ip = "127.0.0.1" - - handler_name = format_event_handler(BenchmarkState.event_handlers["increment"]) - - event = Event( - token=token, - name=handler_name, - router_data={ - "query": {}, - "path": "/", - }, - payload={}, - ) + Args: + num_events: Number of increment events to enqueue. + num_expected_deltas: How many deltas to wait for. + """ + emitted_deltas.clear() - delta_count = 0 - expected_deltas = 0 - - async def run_events(num_events: int, num_expected_deltas: int) -> None: - """Process events through the pipeline and wait for deltas. - - Args: - num_events: Number of increment events to process. - num_expected_deltas: How many StateUpdates to wait for. 
- """ - nonlocal delta_count, expected_deltas - delta_count = 0 - expected_deltas = num_expected_deltas - - for _ in range(num_events): - async with contextlib.aclosing( - process(app, event, sid, headers, client_ip) - ) as updates: - async for _update in updates: - delta_count += 1 - - assert delta_count == expected_deltas, ( - f"Expected {expected_deltas} StateUpdate(s), got {delta_count}" - ) + async with processor as p: + async for _ in asyncio.as_completed([ + await p.enqueue(token, event) for _ in range(num_events) + ]): + pass + assert len(emitted_deltas) == num_expected_deltas - yield run_events + yield run_events - await state_manager.close() + await state_manager.close() def test_process_event( @@ -116,8 +106,7 @@ def test_process_event( """Benchmark processing 3 increment events through the full pipeline. The first event creates fresh state (cold path), the next two reuse - the existing state (warm path). All machinery is set up outside the - benchmark; only the event processing is timed. + the existing state (warm path). Only event processing is timed. Args: event_processing_harness: The run_events async callable. @@ -126,10 +115,8 @@ def test_process_event( run_events = event_processing_harness loop = asyncio.get_event_loop() - # Each call to process() for a non-background event yields StateUpdates. - # The _process_event generator yields one update per yield/return plus a - # final update. For a simple handler like increment() with no yield, - # we get 1 StateUpdate per event = 3 total. + # Each event handler (increment) does a single state mutation with + # no yields, so we expect 1 delta per event = 3 total. 
@benchmark def _(): loop.run_until_complete(run_events(num_events=3, num_expected_deltas=3)) diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 1e05475b7d0..283dbe4cd12 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -335,7 +335,7 @@ def test_background_task( AppHarness.expect(lambda: counter_async_cv.text == "620", timeout=40) # all tasks should have exited and cleaned up AppHarness.expect( - lambda: not background_task.app_instance._background_tasks # pyright: ignore [reportOptionalMemberAccess] + lambda: not background_task.app_instance.event_processor._tasks # pyright: ignore [reportOptionalMemberAccess] ) diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 262a49882d9..1fadc888105 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -11,10 +11,6 @@ from selenium.webdriver.firefox.webdriver import WebDriver as Firefox from selenium.webdriver.remote.webdriver import WebDriver -from reflex.istate.manager.disk import StateManagerDisk -from reflex.istate.manager.memory import StateManagerMemory -from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import State, _substate_key from reflex.testing import AppHarness from . 
import utils @@ -22,6 +18,8 @@ def ClientSide(): """App for testing client-side state.""" + import uuid + import reflex as rx class ClientSideState(rx.State): @@ -36,6 +34,12 @@ def set_state_var(self, value: str): def set_input_value(self, value: str): self.input_value = value + @rx.event + def reset_token_no_hydrate(self): + return rx.run_script( + f"{{token = '{uuid.uuid4()}'; window.sessionStorage.setItem('token', token);}}" + ) + class ClientSideSubState(ClientSideState): # cookies with default settings c1: str = rx.Cookie() @@ -90,6 +94,11 @@ def index(): read_only=True, id="token", ), + rx.button( + "New Token - No Hydrate", + id="new_token", + on_click=ClientSideState.reset_token_no_hydrate, + ), rx.input( placeholder="state var", value=ClientSideState.state_var, @@ -350,7 +359,6 @@ def set_sub_sub(var: str, value: str): set_sub_sub("l1s", "l1s value") set_sub_sub("s1s", "s1s value") - state_name = client_side.get_full_state_name(["_client_side_state"]) sub_state_name = client_side.get_full_state_name([ "_client_side_state", "_client_side_sub_state", @@ -534,9 +542,8 @@ def set_sub_sub(var: str, value: str): assert l1s.text == "l1s value" assert s1s.text == "s1s value" - # reset the backend state to force refresh from client storage - async with client_side.modify_state(f"{token}_{state_name}") as state: - state.reset() + # set a new token to force reloading the values from client + driver.execute_script("window.sessionStorage.setItem('token', '');") driver.refresh() # wait for the backend connection to send the token (again) @@ -640,39 +647,8 @@ def set_sub_sub(var: str, value: str): assert s3.text == "s3 value" # Simulate state expiration - if isinstance(client_side.state_manager, StateManagerRedis): - await client_side.state_manager.redis.delete( - _substate_key(token, State.get_full_name()) - ) - await client_side.state_manager.redis.delete(_substate_key(token, state_name)) - await client_side.state_manager.redis.delete( - _substate_key(token, 
sub_state_name) - ) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_sub_state_name) - ) - elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)): - del client_side.state_manager.states[token] - if ( - client_side.app_instance is not None - and (app_state_manager := client_side.app_instance.state_manager) is not None - and isinstance(app_state_manager, StateManagerDisk) - ): - # Purge the backend's disk manager - app_state_manager.states.pop(token, None) - app_state_manager._write_queue.pop(token, None) - og_token_expiration = app_state_manager.token_expiration - app_state_manager.token_expiration = 0 - app_state_manager._purge_expired_states() - app_state_manager.token_expiration = og_token_expiration - - # Ensure the state is gone (not hydrated) - async def poll_for_not_hydrated(): - state = await client_side.get_state(_substate_key(token or "", state_name)) - assert isinstance(state, State) - return not state.is_hydrated - - assert await AppHarness._poll_for_async(poll_for_not_hydrated) + new_token_btn = driver.find_element(By.ID, "new_token") + new_token_btn.click() # Trigger event to get a new instance of the state since the old was expired. 
set_sub("c1", "c1 post expire") @@ -714,41 +690,6 @@ async def poll_for_not_hydrated(): assert l1s.text == "l1s value" assert s1s.text == "s1s value" - # Get the backend state and ensure the values are still set - async def get_sub_state(): - root_state = await client_side.get_state( - _substate_key(token or "", sub_state_name) - ) - state = root_state.substates[client_side.get_state_name("_client_side_state")] - return state.substates[client_side.get_state_name("_client_side_sub_state")] - - async def poll_for_c1_set(): - sub_state = await get_sub_state() - return sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] - - assert await AppHarness._poll_for_async(poll_for_c1_set) - sub_state = await get_sub_state() - assert sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c2 == "c2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c3 == "" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c4 == "c4 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c5 == "c5 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c6 == "c6 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c7 == "c7 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l1 == "l1 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l2 == "l2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l3 == "l3 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l4 == "l4 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s1 == "s1 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s2 == "s2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s3 == "s3 value" # pyright: ignore[reportAttributeAccessIssue] - sub_sub_state = sub_state.substates[ - client_side.get_state_name("_client_side_sub_sub_state") - ] - 
assert sub_sub_state.c1s == "c1s value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_sub_state.l1s == "l1s value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_sub_state.s1s == "s1s value" # pyright: ignore[reportAttributeAccessIssue] - # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() local_storage.clear() diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index 0230fa129e4..0d674b55b5f 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,7 +5,6 @@ import pytest from selenium.webdriver.common.by import By -from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils @@ -32,15 +31,58 @@ def increment(self): self.count += 1 self._be = self.count # pyright: ignore [reportAttributeAccessIssue] + @rx.event + def assert_be(self, value: E): + assert self._backend_vars != self.backend_vars + assert self._be == int(value) # pyright: ignore [reportAttributeAccessIssue, reportArgumentType] + + @rx.event + def assert_be_none(self): + assert self._backend_vars == self.backend_vars + assert self._be is None # pyright: ignore [reportAttributeAccessIssue] + + @rx.event + def assert_be_int(self, value: int): + assert self._be_int == value # pyright: ignore [reportAttributeAccessIssue] + + @rx.event + def assert_be_str(self, value: str): + assert self._be_str == value # pyright: ignore [reportAttributeAccessIssue] + @classmethod def get_component(cls, *children, **props): + eid = props.get("id", "default") return rx.vstack( *children, - rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"), + rx.heading(cls.count, id=f"count-{eid}"), rx.button( "Increment", on_click=cls.increment, - id=f"button-{props.get('id', 'default')}", + id=f"button-{eid}", + ), + rx.form( + rx.input(id=f"{eid}-assert-be-value", name="be_value"), + rx.button( + "Assert _be", + 
id=f"{eid}-assert-be", + ), + on_submit=lambda fd: cls.assert_be(fd.to(dict)["be_value"]), # pyright: ignore [reportAttributeAccessIssue] + reset_on_submit=True, + ), + rx.button( + "Assert _be_none", + id=f"{eid}-assert-be-none", + on_click=cls.assert_be_none, + ), + rx.button( + "Assert _be_int == 0", + id=f"{eid}-assert-be-int", + on_click=cls.assert_be_int(0), + ), + rx.button( + "Assert _be_str == '42'", + id=f"{eid}-assert-be-str", + on_click=cls.assert_be_str("42"), ), **props, ) @@ -120,8 +162,7 @@ def component_state_app(tmp_path) -> Generator[AppHarness, None, None]: yield harness -@pytest.mark.asyncio -async def test_component_state_app(component_state_app: AppHarness): +def test_component_state_app(component_state_app: AppHarness): """Increment counters independently. Args: @@ -132,7 +173,6 @@ async def test_component_state_app(component_state_app: AppHarness): ss = utils.SessionStorage(driver) assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" - root_state_token = _substate_key(ss.get("token"), State) count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") @@ -141,16 +181,9 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a = driver.find_element(By.ID, "inc-a") # Check that backend vars in mixins are okay - a_state_name = driver.find_element(By.ID, "a_state_name").text - b_state_name = driver.find_element(By.ID, "b_state_name").text - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert a_state._backend_vars == a_state.backend_vars - assert a_state._backend_vars == b_state._backend_vars - assert a_state._backend_vars["_be"] is None - assert a_state._backend_vars["_be_int"] == 0 - assert a_state._backend_vars["_be_str"] == "42" + driver.find_element(By.ID, "a-assert-be-none").click() + driver.find_element(By.ID, 
"a-assert-be-int").click() + driver.find_element(By.ID, "a-assert-be-str").click() assert count_a.text == "0" @@ -163,13 +196,9 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a.click() assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3" - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert a_state._backend_vars != a_state.backend_vars - assert a_state._be == a_state._backend_vars["_be"] == 3 # pyright: ignore[reportAttributeAccessIssue] - assert b_state._be is None # pyright: ignore[reportAttributeAccessIssue] - assert b_state._backend_vars["_be"] is None + driver.find_element(By.ID, "a-assert-be-value").send_keys("3") + driver.find_element(By.ID, "a-assert-be").click() + driver.find_element(By.ID, "b-assert-be-none").click() assert count_b.text == "0" @@ -179,11 +208,8 @@ async def test_component_state_app(component_state_app: AppHarness): button_b.click() assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2" - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert b_state._backend_vars != b_state.backend_vars - assert b_state._be == b_state._backend_vars["_be"] == 2 # pyright: ignore[reportAttributeAccessIssue] + driver.find_element(By.ID, "b-assert-be-value").send_keys("2") + driver.find_element(By.ID, "b-assert-be").click() # Check locally-defined substate style count_c = driver.find_element(By.ID, "count-c") diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index f4fb7a8d5f2..905b1594cb5 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -198,14 +198,6 @@ async def test_computed_vars( """ assert computed_vars.app_instance is not None - state_name = 
computed_vars.get_state_name("_state") - full_state_name = computed_vars.get_full_state_name(["_state"]) - token = f"{token}_{full_state_name}" - state = (await computed_vars.get_state(token)).substates[state_name] - assert state is not None - assert state.count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] - assert state._count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] - # test that backend var is not rendered count1_backend = driver.find_element(By.ID, "count1_backend") assert count1_backend @@ -257,12 +249,6 @@ async def test_computed_vars( computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") == "1" ) - state = (await computed_vars.get_state(token)).substates[state_name] - assert state is not None - assert state.count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] - assert count1_backend.text == "" - assert state._count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] - assert count1_backend_.text == "" mark_dirty.click() with pytest.raises(TimeoutError): diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 25bf74de73a..6019f4b2df5 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,9 +1,13 @@ """Test case for displaying the connection banner when the websocket drops.""" +import asyncio +import contextlib import pickle -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator, Iterator import pytest +import pytest_asyncio +from redis.asyncio import Redis from reflex_base import constants from selenium.common.exceptions import NoSuchElementException from selenium.webdriver.common.by import By @@ -93,6 +97,40 @@ def connection_banner( yield harness +@contextlib.contextmanager +def browser_offline(driver: WebDriver) -> Iterator[None]: + """Context manager that takes the browser offline via CDP and restores it on exit. 
+ + Args: + driver: Selenium WebDriver instance (must support execute_cdp_cmd). + + Yields: + None + """ + driver.execute_cdp_cmd("Network.enable", {}) + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": True, + "downloadThroughput": -1, + "uploadThroughput": -1, + "latency": 0, + }, + ) + try: + yield + finally: + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": False, + "downloadThroughput": -1, + "uploadThroughput": -1, + "latency": 0, + }, + ) + + CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]" @@ -147,12 +185,38 @@ def _assert_token(connection_banner, driver) -> str: return ss.get("token") +@pytest_asyncio.fixture +async def redis( + connection_banner: AppHarness, +) -> AsyncGenerator[Redis | None]: + """Get the Redis instance from the StateManagerRedis used in the connection_banner test. + + Args: + connection_banner: AppHarness instance. + + Yields: + A Redis instance or None if the StateManager is not Redis. + """ + from reflex.utils.prerequisites import get_redis + + redis = None + if (app := connection_banner.app_instance) is not None and isinstance( + app.state_manager, StateManagerRedis + ): + redis = get_redis() + yield redis + if redis is not None: + with contextlib.suppress(Exception, asyncio.CancelledError): + await redis.aclose() + + @pytest.mark.asyncio -async def test_connection_banner(connection_banner: AppHarness): +async def test_connection_banner(connection_banner: AppHarness, redis: Redis | None): """Test that the connection banner is displayed when the websocket drops. Args: connection_banner: AppHarness instance. + redis: Redis instance used by the app, or None if not using Redis. 
""" assert connection_banner.app_instance is not None assert connection_banner.backend is not None @@ -165,11 +229,9 @@ async def test_connection_banner(connection_banner: AppHarness): app_token_manager = connection_banner.token_manager() assert token in app_token_manager.token_to_sid sid_before = app_token_manager.token_to_sid[token] - if isinstance(connection_banner.state_manager, StateManagerRedis): + if redis is not None: assert isinstance(app_token_manager, RedisTokenManager) - assert await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) == pickle.dumps( + assert await redis.get(app_token_manager._get_redis_key(token)) == pickle.dumps( SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before) ) @@ -181,42 +243,28 @@ async def test_connection_banner(connection_banner: AppHarness): increment_button.click() assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" - # Start an long event before killing the backend, to mark event_processing=true + # Start a long event before blocking the network, to mark event_processing=true delay_button.click() - # Get the backend port - backend_port = connection_banner._poll_for_servers().getsockname()[1] - - # Kill the backend - connection_banner.backend.should_exit = True - if connection_banner.backend_thread is not None: - connection_banner.backend_thread.join() + with browser_offline(driver): + # Error modal should now be displayed + AppHarness.expect(lambda: has_error_modal(driver)) - # Error modal should now be displayed - AppHarness.expect(lambda: has_error_modal(driver)) + # The token association should be removed once the websocket closes on the server. 
+ assert connection_banner._poll_for( + lambda: token not in app_token_manager.token_to_sid + ) + if redis is not None: + assert isinstance(app_token_manager, RedisTokenManager) + assert await redis.get(app_token_manager._get_redis_key(token)) is None - # The token association should have been removed when the server exited. - assert token not in app_token_manager.token_to_sid - if isinstance(connection_banner.state_manager, StateManagerRedis): - assert isinstance(app_token_manager, RedisTokenManager) + # Increment the counter while disconnected + increment_button.click() assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - is None + connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" ) - # Increment the counter with backend down - increment_button.click() - assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" - - # Bring the backend back up - connection_banner._start_backend(port=backend_port) - - # Create a new StateManager to avoid async loop affinity issues w/ redis - await connection_banner._reset_backend_state_manager() - - # Banner should be gone now + # Banner should be gone now (network restored on context manager exit) AppHarness.expect(lambda: not has_error_modal(driver)) # After reconnecting, the token association should be re-established. @@ -224,11 +272,9 @@ async def test_connection_banner(connection_banner: AppHarness): # Make sure the new connection has a different websocket sid. 
sid_after = app_token_manager.token_to_sid[token] assert sid_before != sid_after - if isinstance(connection_banner.state_manager, StateManagerRedis): + if redis is not None: assert isinstance(app_token_manager, RedisTokenManager) - assert await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) == pickle.dumps( + assert await redis.get(app_token_manager._get_redis_key(token)) == pickle.dumps( SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after) ) diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 2f573f192c5..7c3e7641f86 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -3,7 +3,8 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Generator +import json +from collections.abc import Generator from urllib.parse import urlsplit import pytest @@ -11,7 +12,7 @@ from reflex.testing import AppHarness, WebDriver -from .utils import poll_for_navigation +from .utils import poll_assert_event_order, poll_for_navigation def DynamicRoute(): @@ -47,6 +48,10 @@ def next_page(self) -> str: except ValueError: return "0" + @rx.var + def params(self) -> dict[str, str | list[str]]: + return self.router.page.params + def index(): return rx.fragment( rx.input( @@ -68,12 +73,14 @@ def index(): id="link_page_next", ), rx.link("missing", href="/missing", id="link_missing"), - rx.list( # pyright: ignore [reportAttributeAccessIssue] + rx.vstack( rx.foreach( DynamicState.order, # pyright: ignore [reportAttributeAccessIssue] - lambda i: rx.list_item(rx.text(i)), + rx.text, ), + id="event_order", ), + rx.text(DynamicState.params.to_string(), id="params"), ) class ArgState(rx.State): @@ -215,46 +222,10 @@ def token(dynamic_route: AppHarness, driver: WebDriver) -> str: return token -@pytest.fixture -def poll_for_order( - dynamic_route: AppHarness, token: str -) -> Callable[[list[str]], 
Coroutine[None, None, None]]: - """Poll for the order list to match the expected order. - - Args: - dynamic_route: harness for DynamicRoute app. - token: The token visible in the driver browser. - - Returns: - An async function that polls for the order list to match the expected order. - """ - dynamic_state_name = dynamic_route.get_state_name("_dynamic_state") - dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) - - async def _poll_for_order(exp_order: list[str]): - async def _backend_state(): - return await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") - - async def _check(): - return (await _backend_state()).substates[ - dynamic_state_name - ].order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - await AppHarness._poll_for_async(_check, timeout=10) - assert ( - list((await _backend_state()).substates[dynamic_state_name].order) # pyright: ignore[reportAttributeAccessIssue] - == exp_order - ) - - return _poll_for_order - - -@pytest.mark.asyncio -async def test_on_load_navigate( +def test_on_load_navigate( dynamic_route: AppHarness, driver: WebDriver, token: str, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between dynamic pages with on_load event. @@ -262,9 +233,7 @@ async def test_on_load_navigate( dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. token: The token visible in the driver browser. - poll_for_order: function that polls for the order list to match the expected order. 
""" - dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) assert dynamic_route.app_instance is not None link = driver.find_element(By.ID, "link_page_next") assert link @@ -290,7 +259,7 @@ async def test_on_load_navigate( page_id_input, exp_not_equal=str(ix - 1) ) == str(ix) assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}" - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) frontend_url = dynamic_route.frontend_url assert frontend_url @@ -300,48 +269,46 @@ async def test_on_load_navigate( exp_order += ["/page/[page_id]-10"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/page/10") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # make sure internal nav still hydrates after redirect exp_order += ["/page/[page_id]-11"] link = driver.find_element(By.ID, "link_page_next") with poll_for_navigation(driver): link.click() - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # load same page with a query param and make sure it passes through exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.get(f"{driver.current_url}?foo=bar") - await poll_for_order(exp_order) - assert ( - await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") - ).router.page.params["foo"] == "bar" + poll_assert_event_order(driver, exp_order) + params_json = driver.find_element(By.ID, "params") + params = json.loads(params_json.text) + assert params == {"foo": "bar", "page_id": "11"} # hit a 404 and ensure we still hydrate exp_order += ["/404-no page id"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/missing") - await poll_for_order(exp_order) # browser nav should still trigger hydration exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.back() - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # next/link to a 404 and ensure we still hydrate exp_order += 
["/404-no page id"] link = driver.find_element(By.ID, "link_missing") with poll_for_navigation(driver): link.click() - await poll_for_order(exp_order) # hit a page that redirects back to dynamic page exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/redirect-page/0/?foo=bar") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # should have redirected back to page 0 assert urlsplit(driver.current_url).path.removesuffix("/") == "/page/0" @@ -349,21 +316,18 @@ async def test_on_load_navigate( exp_order += ["on-load-static"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/page/static") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) -@pytest.mark.asyncio -async def test_on_load_navigate_non_dynamic( +def test_on_load_navigate_non_dynamic( dynamic_route: AppHarness, driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between static pages with on_load event. Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. 
""" assert dynamic_route.app_instance is not None link = driver.find_element(By.ID, "link_page_x") @@ -372,7 +336,7 @@ async def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id"]) + poll_assert_event_order(driver, ["/static/x-no page id"]) # go back to the index and navigate back to the static route link = driver.find_element(By.ID, "link_index") @@ -384,13 +348,13 @@ async def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) + poll_assert_event_order(driver, ["/static/x-no page id", "/static/x-no page id"]) for _ in range(3): link = driver.find_element(By.ID, "link_page_x") link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id"] * 5) + poll_assert_event_order(driver, ["/static/x-no page id"] * 5) @pytest.mark.asyncio diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index 801d1c24de9..f253c306fd1 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -4,7 +4,7 @@ import asyncio import time -from collections.abc import Callable, Coroutine, Generator +from collections.abc import Generator import pytest from selenium.webdriver.common.by import By @@ -12,8 +12,8 @@ from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait -from reflex.state import BaseState from reflex.testing import AppHarness, WebDriver +from tests.integration.utils import poll_assert_event_order def TestEventAction(): @@ -158,11 +158,12 @@ def index(): 200 ).stop_propagation, ), - rx.list( # pyright: ignore [reportAttributeAccessIssue] + 
rx.vstack( rx.foreach( EventActionState.order, - rx.list_item, + rx.text, ), + id="event_order", ), on_click=EventActionState.on_click("outer"), # pyright: ignore [reportCallIssue] ), rx.form( @@ -245,36 +246,6 @@ def token(event_action: AppHarness, driver: WebDriver) -> str: return token -async def _backend_state(app: AppHarness, token: str) -> BaseState: - state_name = app.get_state_name("_event_action_state") - state_full_name = app.get_full_state_name(["_event_action_state"]) - return (await app.get_state(f"{token}_{state_full_name}")).substates[state_name] - - -@pytest.fixture -def poll_for_order( - event_action: AppHarness, token: str -) -> Callable[[list[str]], Coroutine[None, None, None]]: - """Poll for the order list to match the expected order. - - Args: - event_action: harness for TestEventAction app. - token: The token visible in the driver browser. - - Returns: - An async function that polls for the order list to match the expected order. - """ - - async def _poll_for_order(exp_order: list[str]): - async def _check(): - return (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - await AppHarness._poll_for_async(_check) - assert (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - return _poll_for_order - - @pytest.mark.parametrize( ("element_id", "exp_order"), [ @@ -303,7 +274,6 @@ async def _check(): @pytest.mark.asyncio async def test_event_actions( driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], element_id: str, exp_order: list[str], ): @@ -311,7 +281,6 @@ async def test_event_actions( Args: driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. element_id: The id of the element to click. exp_order: The expected order of events. 
""" @@ -324,7 +293,7 @@ async def test_event_actions( if "on_click:outer" not in exp_order: # really make sure the outer event is not fired await asyncio.sleep(0.5) - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) if element_id.startswith("link") and "prevent-default" not in element_id: assert driver.current_url != prev_url @@ -332,8 +301,7 @@ async def test_event_actions( assert driver.current_url == prev_url -@pytest.mark.asyncio -async def test_event_actions_throttle_debounce( +def test_event_actions_throttle_debounce( event_action: AppHarness, driver: WebDriver, token: str, @@ -358,14 +326,16 @@ async def test_event_actions_throttle_debounce( btn_debounce.click() # Wait until the debounce event shows up - async def _debounce_received(): - state = await _backend_state(event_action, token) - return state.order and state.order[-1] == "on_click_debounce" # pyright: ignore[reportAttributeAccessIssue] + def _debounce_received(): + order = driver.find_elements(By.XPATH, '//*[@id="event_order"]/p') + return len(order) and order[-1].text == "on_click_debounce" - await AppHarness._poll_for_async(_debounce_received) + AppHarness._poll_for(_debounce_received) # This test is inherently racy, so ensure the `on_click_throttle` event is fired approximately the expected number of times. 
- final_event_order = (await _backend_state(event_action, token)).order # pyright: ignore[reportAttributeAccessIssue] + final_event_order = [ + elem.text for elem in driver.find_elements(By.XPATH, '//*[@id="event_order"]/p') + ] n_on_click_throttle_received = final_event_order.count("on_click_throttle") print( f"Expected ~{exp_events} on_click_throttle events, received {n_on_click_throttle_received}" @@ -377,16 +347,13 @@ async def _debounce_received(): @pytest.mark.usefixtures("token") -@pytest.mark.asyncio -async def test_event_actions_dialog_form_in_form( +def test_event_actions_dialog_form_in_form( driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links and buttons and assert on fired events. Args: driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. """ open_dialog_id = "btn-dialog" submit_button_id = "btn-submit" @@ -400,4 +367,4 @@ async def test_event_actions_dialog_form_in_form( btn_no_events = wait.until(EC.element_to_be_clickable((By.ID, "btn-no-events"))) btn_no_events.location_once_scrolled_into_view btn_no_events.click() - await poll_for_order(["on_submit", "on_click:outer"]) + poll_assert_event_order(driver, ["on_submit", "on_click:outer"]) diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index 289f12cbe80..942b70994e6 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -9,6 +9,10 @@ from selenium.webdriver.common.by import By from reflex.testing import AppHarness, WebDriver +from tests.integration.utils import ( + poll_assert_event_order, + poll_assert_relative_event_order, +) MANY_EVENTS = 50 @@ -146,14 +150,20 @@ def click_yield_interim_value(self): app = rx.App() - token_input = rx.input( - value=State.router.session.client_token, is_read_only=True, id="token" + common_elements = rx.vstack( + rx.input( + value=State.router.session.client_token, 
is_read_only=True, id="token" + ), + rx.vstack( + rx.foreach(State.event_order, lambda x: rx.text(x)), id="event_order" + ), + rx.input(value=State.is_hydrated, is_read_only=True, id="is_hydrated"), ) @app.add_page def index(): return rx.fragment( - token_input, + common_elements, rx.input(value=State.interim_value, is_read_only=True, id="interim_value"), rx.button( "Return Event", @@ -225,13 +235,13 @@ def index(): def on_load_return_chain(): return rx.fragment( rx.text("return"), - token_input, + common_elements, ) def on_load_yield_chain(): return rx.fragment( rx.text("yield"), - token_input, + common_elements, ) def on_mount_return_chain(): @@ -241,7 +251,7 @@ def on_mount_return_chain(): on_mount=State.on_load_return_chain, on_unmount=lambda: State.event_arg("unmount"), ), - token_input, + common_elements, rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), ) @@ -255,7 +265,7 @@ def on_mount_yield_chain(): ], on_unmount=State.event_no_args, ), - token_input, + common_elements, rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), ) @@ -315,6 +325,7 @@ def event_chain_strict(tmp_path_factory) -> Generator[AppHarness, None, None]: with AppHarness.create( root=tmp_path_factory.mktemp("event_chain_strict"), app_source=EventChain, + app_name="event_chain_strict", ) as harness: yield harness @@ -440,8 +451,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str: ), ], ) -@pytest.mark.asyncio -async def test_event_chain_click( +def test_event_chain_click( event_chain: AppHarness, driver: WebDriver, button_id: str, @@ -455,19 +465,11 @@ async def test_event_chain_click( button_id: the ID of the button to click exp_event_order: the expected events recorded in the State """ - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") + assert_token(event_chain, driver) btn = driver.find_element(By.ID, button_id) btn.click() - async def _has_all_events(): - return len( - (await 
event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) - - await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - assert event_order == exp_event_order + poll_assert_event_order(driver, exp_event_order) @pytest.mark.parametrize( @@ -493,8 +495,7 @@ async def _has_all_events(): ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_load( +def test_event_chain_on_load( event_chain: AppHarness, driver: WebDriver, uri: str, @@ -510,52 +511,67 @@ async def test_event_chain_on_load( """ assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url.removesuffix("/") + uri) - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") - - async def _has_all_events(): - return len( - (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) + assert_token(event_chain, driver) - await AppHarness._poll_for_async(_has_all_events) - backend_state = (await event_chain.get_state(token)).substates[state_name] - assert backend_state.event_order == exp_event_order # pyright: ignore[reportAttributeAccessIssue] - assert backend_state.is_hydrated is True # pyright: ignore[reportAttributeAccessIssue] + poll_assert_event_order(driver, exp_event_order) + assert ( + event_chain.poll_for_value( + driver.find_element(By.ID, "is_hydrated"), exp_not_equal="false" + ) + == "true" + ) @pytest.mark.parametrize( - ("uri", "exp_event_order"), + ("uri", "expected_counts", "ordering_rules"), [ ( "/on-mount-return-chain", + { + "on_load_return_chain": 1, + "event_arg:1": 1, + "event_arg:2": 1, + "event_arg:3": 1, + "event_arg:unmount": 1, + }, [ - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:unmount", + # 
on_load before chain and unmount + (("on_load_return_chain", 0), ("event_arg:1", 0)), + (("on_load_return_chain", 0), ("event_arg:unmount", 0)), + # Chain in order + (("event_arg:1", 0), ("event_arg:2", 0)), + (("event_arg:2", 0), ("event_arg:3", 0)), ], ), ( "/on-mount-yield-chain", + { + "on_load_yield_chain": 1, + "event_arg:4": 1, + "event_arg:5": 1, + "event_arg:6": 1, + "event_arg:mount": 1, + "event_no_args": 1, + }, [ - "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_no_args", + # on_load before chain and mount + (("on_load_yield_chain", 0), ("event_arg:4", 0)), + (("on_load_yield_chain", 0), ("event_arg:mount", 0)), + # Chain in order + (("event_arg:4", 0), ("event_arg:5", 0)), + (("event_arg:5", 0), ("event_arg:6", 0)), + # mount before event_no_args + (("event_arg:mount", 0), ("event_no_args", 0)), ], ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_mount( +def test_event_chain_on_mount( event_chain: AppHarness, driver: WebDriver, uri: str, - exp_event_order: list[str], + expected_counts: dict[str, int], + ordering_rules: list, ): """Load the URI, assert that the events are handled in the correct order. 
@@ -568,7 +584,8 @@ async def test_event_chain_on_mount( event_chain: AppHarness for the event_chain app driver: selenium WebDriver open to the app uri: the page to load - exp_event_order: the expected events recorded in the State + expected_counts: mapping of event name to expected occurrence count + ordering_rules: relative ordering constraints between event occurrences """ assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url.removesuffix("/") + uri) @@ -576,63 +593,80 @@ async def test_event_chain_on_mount( unmount_button = AppHarness.poll_for_or_raise_timeout( lambda: driver.find_element(By.ID, "unmount") ) - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") + assert_token(event_chain, driver) unmount_button.click() - async def _has_all_events(): - return len( - (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) - - await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - assert list(event_order) == exp_event_order + poll_assert_relative_event_order(driver, expected_counts, ordering_rules) @pytest.mark.parametrize( - ("uri", "exp_event_order"), + ("uri", "expected_counts", "ordering_rules"), [ ( "/on-mount-return-chain", + { + "on_load_return_chain": 2, + "event_arg:1": 2, + "event_arg:2": 2, + "event_arg:3": 2, + "event_arg:unmount": 2, + }, [ - "on_load_return_chain", - "event_arg:unmount", - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:unmount", + # First on_load before first chain and first unmount + (("on_load_return_chain", 0), ("event_arg:1", 0)), + (("on_load_return_chain", 0), ("event_arg:unmount", 0)), + # First chain in order + (("event_arg:1", 0), ("event_arg:2", 0)), + 
(("event_arg:2", 0), ("event_arg:3", 0)), + # First unmount before second on_load + (("event_arg:unmount", 0), ("on_load_return_chain", 1)), + # Second on_load before second chain and second unmount + (("on_load_return_chain", 1), ("event_arg:1", 1)), + (("on_load_return_chain", 1), ("event_arg:unmount", 1)), + # Second chain in order + (("event_arg:1", 1), ("event_arg:2", 1)), + (("event_arg:2", 1), ("event_arg:3", 1)), ], ), ( "/on-mount-yield-chain", + { + "on_load_yield_chain": 2, + "event_arg:4": 2, + "event_arg:5": 2, + "event_arg:6": 2, + "event_arg:mount": 2, + "event_no_args": 2, + }, [ - "on_load_yield_chain", - "event_arg:mount", - "event_no_args", - "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_no_args", + # First on_load before first chain and first mount + (("on_load_yield_chain", 0), ("event_arg:4", 0)), + (("on_load_yield_chain", 0), ("event_arg:mount", 0)), + # First chain in order + (("event_arg:4", 0), ("event_arg:5", 0)), + (("event_arg:5", 0), ("event_arg:6", 0)), + # First mount before first event_no_args + (("event_arg:mount", 0), ("event_no_args", 0)), + # First event_no_args before second on_load + (("event_no_args", 0), ("on_load_yield_chain", 1)), + # Second on_load before second chain and second mount + (("on_load_yield_chain", 1), ("event_arg:4", 1)), + (("on_load_yield_chain", 1), ("event_arg:mount", 1)), + # Second chain in order + (("event_arg:4", 1), ("event_arg:5", 1)), + (("event_arg:5", 1), ("event_arg:6", 1)), + # Second mount before second event_no_args + (("event_arg:mount", 1), ("event_no_args", 1)), ], ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_mount_strict( +def test_event_chain_on_mount_strict( event_chain_strict: AppHarness, driver_strict: WebDriver, uri: str, - exp_event_order: list[str], + expected_counts: dict[str, int], + ordering_rules: list, ): """Run the test_event_chain_on_mount test with 
strict mode enabled. @@ -640,13 +674,15 @@ async def test_event_chain_on_mount_strict( event_chain_strict: AppHarness for the event_chain app with strict mode enabled driver_strict: selenium WebDriver open to the app with strict mode enabled uri: the page to load - exp_event_order: the expected events recorded in the State + expected_counts: mapping of event name to expected occurrence count + ordering_rules: relative ordering constraints between event occurrences """ - await test_event_chain_on_mount( + test_event_chain_on_mount( event_chain=event_chain_strict, driver=driver_strict, uri=uri, - exp_event_order=exp_event_order, + expected_counts=expected_counts, + ordering_rules=ordering_rules, ) diff --git a/tests/integration/test_form_submit.py b/tests/integration/test_form_submit.py index fde4b6746b8..bee75664ea3 100644 --- a/tests/integration/test_form_submit.py +++ b/tests/integration/test_form_submit.py @@ -2,6 +2,7 @@ import asyncio import functools +import json from collections.abc import Generator import pytest @@ -21,7 +22,7 @@ def FormSubmit(form_component): import reflex as rx class FormState(rx.State): - form_data: dict = {} + form_data: rx.Field[dict] = rx.field(default_factory=dict) var_options: list[str] = ["option3", "option4"] @@ -65,6 +66,7 @@ def index(): on_submit=FormState.form_submit, custom_attrs={"action": "/invalid"}, ), + rx.text(FormState.form_data.to_string(), id="form-data"), rx.spacer(), height="100vh", ) @@ -79,7 +81,7 @@ def FormSubmitName(form_component): import reflex as rx class FormState(rx.State): - form_data: dict = {} + form_data: rx.Field[dict] = rx.field(default_factory=dict) val: str = "foo" options: list[str] = ["option1", "option2"] @@ -122,6 +124,7 @@ def index(): on_submit=FormState.form_submit, custom_attrs={"action": "/invalid"}, ), + rx.text(FormState.form_data.to_string(), id="form-data"), rx.spacer(), height="100vh", ) @@ -225,20 +228,12 @@ async def test_submit(driver, form_submit: AppHarness): submit_input = 
driver.find_element(By.CLASS_NAME, "rt-Button") submit_input.click() - state_name = form_submit.get_state_name("_form_state") - full_state_name = form_submit.get_full_state_name(["_form_state"]) - - async def get_form_data(): - return ( - (await form_submit.get_state(f"{token}_{full_state_name}")) - .substates[state_name] - .form_data # pyright: ignore[reportAttributeAccessIssue] - ) - # wait for the form data to arrive at the backend - form_data = await AppHarness._poll_for_async(get_form_data) + form_submit.poll_for_content( + driver.find_element(By.ID, "form-data"), exp_not_equal="{}" + ) + form_data = json.loads(driver.find_element(By.ID, "form-data").text) assert isinstance(form_data, dict) - form_data = format.collect_form_dict_names(form_data) print(form_data) diff --git a/tests/integration/test_input.py b/tests/integration/test_input.py index e4859a3a9de..fda2752c289 100644 --- a/tests/integration/test_input.py +++ b/tests/integration/test_input.py @@ -20,6 +20,10 @@ class State(rx.State): def set_text(self, text: str): self.text = text + @rx.event + def do_clear(self): + self.text = "" + app = rx.App() @app.add_page @@ -28,6 +32,11 @@ def index(): rx.input( value=State.router.session.client_token, is_read_only=True, id="token" ), + rx.button( + "Clear State", + on_click=State.do_clear, + id="clear-backend", + ), rx.input( id="debounce_input_input", on_change=State.set_text, @@ -72,8 +81,7 @@ def fully_controlled_input(tmp_path) -> Generator[AppHarness, None, None]: yield harness -@pytest.mark.asyncio -async def test_fully_controlled_input(fully_controlled_input: AppHarness): +def test_fully_controlled_input(fully_controlled_input: AppHarness): """Type text after moving cursor. Update text on backend. 
Args: @@ -91,13 +99,6 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): token = fully_controlled_input.poll_for_value(token_input) assert token - state_name = fully_controlled_input.get_state_name("_state") - full_state_name = fully_controlled_input.get_full_state_name(["_state"]) - - async def get_state_text(): - state = await fully_controlled_input.get_state(f"{token}_{full_state_name}") - return state.substates[state_name].text # pyright: ignore[reportAttributeAccessIssue] - # ensure defaults are set correctly assert ( fully_controlled_input.poll_for_value( @@ -142,15 +143,10 @@ async def get_state_text(): lambda: fully_controlled_input.poll_for_value(value_input) == "ifoonitial" ) assert debounce_input.get_attribute("value") == "ifoonitial" - assert await get_state_text() == "ifoonitial" assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial" # clear the input on the backend - async with fully_controlled_input.modify_state( - f"{token}_{full_state_name}" - ) as state: - state.substates[state_name].text = "" - assert await get_state_text() == "" + driver.find_element(By.ID, "clear-backend").click() assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -166,7 +162,6 @@ async def get_state_text(): ) ) assert debounce_input.get_attribute("value") == "getting testing done" - assert await get_state_text() == "getting testing done" assert ( fully_controlled_input.poll_for_value(plain_value_input) == "getting testing done" @@ -181,7 +176,6 @@ async def get_state_text(): ) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert await get_state_text() == "overwrite the state" assert ( fully_controlled_input.poll_for_value(plain_value_input) == "overwrite the state" diff --git a/tests/integration/test_memory_state_manager_expiration.py 
b/tests/integration/test_memory_state_manager_expiration.py index a6feda0f0ef..f4d3e88d7dc 100644 --- a/tests/integration/test_memory_state_manager_expiration.py +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -57,7 +57,9 @@ def memory_expiration_app( app_name=f"memory_expiration_{app_harness_env.__name__.lower()}", app_source=MemoryExpirationApp, ) as harness: - assert isinstance(harness.state_manager, StateManagerMemory) + assert isinstance( + getattr(harness.app_instance, "state_manager", None), StateManagerMemory + ) yield harness diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index 1f21f22187c..4c7f8e26997 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio +import json import time from collections.abc import Generator from pathlib import Path -from typing import Any, cast +from typing import Any from urllib.parse import urlsplit import pytest @@ -21,12 +22,14 @@ def UploadFile(): """App for testing dynamic routes.""" + import shutil + import reflex as rx LARGE_DATA = "DUMMY" * 1024 * 512 class UploadState(rx.State): - _file_data: dict[str, str] = {} + upload_done: rx.Field[bool] = rx.field(False) event_order: rx.Field[list[str]] = rx.field([]) progress_dicts: rx.Field[list[dict]] = rx.field([]) stream_progress_dicts: rx.Field[list[dict]] = rx.field([]) @@ -38,29 +41,42 @@ class UploadState(rx.State): @rx.event async def handle_upload(self, files: list[rx.UploadFile]): + self.upload_done = False for file in files: upload_data = await file.read() - self._file_data[file.name or ""] = upload_data.decode("utf-8") + if not file.name: + continue + local_file = rx.get_upload_dir() / file.name + local_file.parent.mkdir(parents=True, exist_ok=True) + local_file.write_bytes(upload_data) + self.upload_done = True @rx.event async def handle_upload_secondary(self, files: list[rx.UploadFile]): + self.upload_done = False 
for file in files: upload_data = await file.read() - self._file_data[file.name or ""] = upload_data.decode("utf-8") + if not file.name: + continue + local_file = rx.get_upload_dir() / file.name + local_file.parent.mkdir(parents=True, exist_ok=True) + local_file.write_bytes(upload_data) self.large_data = LARGE_DATA yield UploadState.chain_event @rx.event def upload_progress(self, progress): assert progress - self.event_order.append("upload_progress") + print(self.event_order) self.progress_dicts.append(progress) @rx.event def chain_event(self): assert self.large_data == LARGE_DATA self.large_data = "" + self.upload_done = True self.event_order.append("chain_event") + print(self.event_order) @rx.event def stream_upload_progress(self, progress): @@ -69,17 +85,23 @@ def stream_upload_progress(self, progress): @rx.event async def handle_upload_tertiary(self, files: list[rx.UploadFile]): + self.upload_done = False for file in files: (rx.get_upload_dir() / (file.name or "INVALID")).write_bytes( await file.read() ) + self.upload_done = True @rx.event async def handle_upload_quaternary(self, files: list[rx.UploadFile]): + self.upload_done = False self.quaternary_names = [file.name for file in files if file.name] + self.upload_done = True @rx.event(background=True) async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): + async with self: + self.upload_done = False upload_dir = rx.get_upload_dir() / "streaming" file_handles: dict[str, Any] = {} @@ -106,11 +128,17 @@ async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): async with self: self.stream_completed_files = sorted(file_handles) + self.upload_done = True @rx.event def do_download(self): return rx.download(rx.get_upload_url("test.txt")) + @rx.event + def clear_uploads(self): + shutil.rmtree(rx.get_upload_dir(), ignore_errors=True) + self.reset() + def index(): return rx.vstack( rx.input( @@ -118,6 +146,16 @@ def index(): read_only=True, id="token", ), + rx.input( + 
value=UploadState.upload_done.to_string(), + read_only=True, + id="upload_done", + ), + rx.button( + "Clear Uploaded Files", + id="clear_uploads", + on_click=UploadState.clear_uploads, + ), rx.heading("Default Upload"), rx.upload.root( rx.vstack( @@ -177,7 +215,8 @@ def index(): rx.foreach( UploadState.progress_dicts, lambda d: rx.text(d.to_string()), - ) + ), + id="progress_dicts", ), rx.button( "Cancel", @@ -265,6 +304,13 @@ def index(): UploadState.stream_completed_files.to_string(), id="stream_completed_files", ), + rx.vstack( + rx.foreach( + UploadState.stream_progress_dicts, + lambda d: rx.text(d.to_string()), + ), + id="stream_progress_dicts", + ), rx.text(UploadState.event_order.to_string(), id="event-order"), ) @@ -282,11 +328,18 @@ def upload_file(tmp_path_factory) -> Generator[AppHarness, None, None]: Yields: running AppHarness instance """ - with AppHarness.create( - root=tmp_path_factory.mktemp("upload_file"), - app_source=UploadFile, - ) as harness: - yield harness + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setenv( + "REFLEX_UPLOADED_FILES_DIR", str(tmp_path_factory.mktemp("uploaded_files")) + ) + try: + with AppHarness.create( + root=tmp_path_factory.mktemp("upload_file"), + app_source=UploadFile, + ) as harness: + yield harness + finally: + monkeypatch.undo() @pytest.fixture @@ -327,8 +380,7 @@ def poll_for_token(driver: WebDriver, upload_file: AppHarness) -> str: @pytest.mark.parametrize("secondary", [False, True]) -@pytest.mark.asyncio -async def test_upload_file( +def test_upload_file( tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool ): """Submit a file upload and check that it arrived on the backend. 
@@ -340,10 +392,9 @@ async def test_upload_file( secondary: whether to use the secondary upload form """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() suffix = "_secondary" if secondary else "" @@ -366,27 +417,20 @@ async def test_upload_file( selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert Path(selected_files.text).name == Path(exp_name).name + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" + if secondary: event_order_displayed = driver.find_element(By.ID, "event-order") AppHarness.expect(lambda: "chain_event" in event_order_displayed.text) - - state = await upload_file.get_state(substate_token) - # only the secondary form tracks progress and chain events - assert state.substates[state_name].event_order.count("upload_progress") == 1 # pyright: ignore[reportAttributeAccessIssue] - assert state.substates[state_name].event_order.count("chain_event") == 1 # pyright: ignore[reportAttributeAccessIssue] + progress_dicts = driver.find_elements(By.XPATH, "//*[@id='progress_dicts']/p") + assert len(progress_dicts) > 0 + assert json.loads(progress_dicts[-1].text)["progress"] == 1 # look up the backend state and assert on uploaded contents - async def get_file_data(): - return ( - (await upload_file.get_state(substate_token)) - .substates[state_name] - ._file_data # pyright: ignore[reportAttributeAccessIssue] - ) - - file_data = await AppHarness._poll_for_async(get_file_data) - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - 
assert normalized_file_data[Path(exp_name).name] == exp_contents + actual_contents = (rx.get_upload_dir() / exp_name).read_text() + assert actual_contents == exp_contents @pytest.mark.asyncio @@ -399,10 +443,9 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): driver: WebDriver instance. """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_element(By.XPATH, "//input[@type='file']") assert upload_box @@ -430,19 +473,13 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # do the upload upload_button.click() - # look up the backend state and assert on uploaded contents - async def get_file_data(): - return ( - (await upload_file.get_state(substate_token)) - .substates[state_name] - ._file_data # pyright: ignore[reportAttributeAccessIssue] - ) + # Wait for the upload to complete. 
+ upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" - file_data = await AppHarness._poll_for_async(get_file_data) - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - for exp_name, exp_contents in exp_files.items(): - assert normalized_file_data[Path(exp_name).name] == exp_contents + for exp_name, exp_content in exp_files.items(): + actual_contents = (rx.get_upload_dir() / exp_name).read_text() + assert actual_contents == exp_content @pytest.mark.parametrize("secondary", [False, True]) @@ -459,6 +496,8 @@ def test_clear_files( """ assert upload_file.app_instance is not None poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() suffix = "_secondary" if secondary else "" @@ -520,10 +559,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive "latency": 200, # 200ms }, ) - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1] upload_button = driver.find_element(By.ID, "upload_button_secondary") @@ -543,23 +579,11 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # Wait a bit for the upload to get cancelled. await asyncio.sleep(12) - # Get interim progress dicts saved in the on_upload_progress handler. 
- async def _progress_dicts(): - state = await upload_file.get_state(substate_token) - return state.substates[state_name].progress_dicts # pyright: ignore[reportAttributeAccessIssue] - - # We should have _some_ progress - assert await AppHarness._poll_for_async(_progress_dicts) - # But there should never be a final progress record for a cancelled upload. - for p in await _progress_dicts(): - assert p["progress"] != 1 + for p in driver.find_elements(By.XPATH, "//*[@id='progress_dicts']/p"): + assert json.loads(p.text)["progress"] != 1 - state = await upload_file.get_state(substate_token) - file_data = state.substates[state_name]._file_data # pyright: ignore[reportAttributeAccessIssue] - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - assert Path(exp_name).name not in normalized_file_data + assert not (rx.get_upload_dir() / exp_name).exists() target_file.unlink() @@ -568,10 +592,9 @@ async def _progress_dicts(): async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebDriver): """Submit a streaming upload and check that chunks are processed incrementally.""" assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] upload_button = driver.find_element(By.ID, "upload_button_streaming") @@ -598,28 +621,6 @@ async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebD AppHarness.expect(lambda: "stream1.txt" in chunk_records_display.text) - async def _stream_completed(): - state = await upload_file.get_state(substate_token) - return ( - len( - 
state.substates[state_name].stream_completed_files # pyright: ignore[reportAttributeAccessIssue] - ) - == 2 - ) - - await AppHarness._poll_for_async(_stream_completed) - - state = await upload_file.get_state(substate_token) - substate = cast(Any, state.substates[state_name]) - chunk_records = substate.stream_chunk_records - - assert len(chunk_records) > 2 - assert {Path(record.split(":")[0]).name for record in chunk_records} == { - "stream1.txt", - "stream2.txt", - } - assert substate.stream_completed_files == ["stream1.txt", "stream2.txt"] - AppHarness.expect( lambda: ( "stream1.txt" in completed_files_display.text @@ -627,6 +628,10 @@ async def _stream_completed(): ) ) + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" + for exp_name, exp_contents in exp_files.items(): assert ( rx.get_upload_dir() / "streaming" / exp_name @@ -651,10 +656,7 @@ async def test_cancel_upload_chunk( "latency": 200, # 200ms }, ) - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] upload_button = driver.find_element(By.ID, "upload_button_streaming") @@ -668,26 +670,16 @@ async def test_cancel_upload_chunk( upload_box.send_keys(str(target_file)) upload_button.click() - await asyncio.sleep(1) + await asyncio.sleep(2) cancel_button.click() - await asyncio.sleep(12) - - async def _stream_progress_dicts(): - state = await upload_file.get_state(substate_token) - return ( - state.substates[state_name].stream_progress_dicts # pyright: ignore[reportAttributeAccessIssue] - ) - - assert await AppHarness._poll_for_async(_stream_progress_dicts) + await asyncio.sleep(11) - for progress in await 
_stream_progress_dicts(): - assert progress["progress"] != 1 + # But there should never be a final progress record for a cancelled upload. + for p in driver.find_elements(By.XPATH, "//*[@id='stream_progress_dicts']/p"): + assert json.loads(p.text)["progress"] != 1 - state = await upload_file.get_state(substate_token) - substate = cast(Any, state.substates[state_name]) - assert substate.stream_completed_files == [] - assert substate.stream_chunk_records + assert not (rx.get_upload_dir() / exp_name).exists() partial_path = rx.get_upload_dir() / "streaming" / exp_name assert partial_path.exists() @@ -715,6 +707,8 @@ def test_upload_download_file( """ assert upload_file.app_instance is not None poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[2] assert upload_box @@ -749,8 +743,7 @@ def test_upload_download_file( assert driver.find_element(by=By.TAG_NAME, value="body").text == exp_contents -@pytest.mark.asyncio -async def test_on_drop( +def test_on_drop( tmp_path, upload_file: AppHarness, driver: WebDriver, @@ -763,10 +756,9 @@ async def test_on_drop( driver: WebDriver instance. """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[ 3 @@ -781,16 +773,18 @@ async def test_on_drop( # Simulate file drop by directly setting the file input upload_box.send_keys(str(target_file)) - # Wait for the on_drop event to be processed - await asyncio.sleep(0.5) + # Wait for the upload to complete. 
+ upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" - async def exp_name_in_quaternary(): - state = await upload_file.get_state(substate_token) - return exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] + def exp_name_in_quaternary(): + quaternary_files = driver.find_element(By.ID, "quaternary_files").text + if quaternary_files: + files = json.loads(quaternary_files) + return exp_name in files + return False # Poll until the file names appear in the display - await AppHarness._poll_for_async(exp_name_in_quaternary) + AppHarness._poll_for(exp_name_in_quaternary) - # Verify through state that the file names were captured correctly - state = await upload_file.get_state(substate_token) - assert exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] + assert exp_name_in_quaternary() diff --git a/tests/integration/utils.py b/tests/integration/utils.py index d6b705551f9..dd70cb4ad71 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -2,9 +2,10 @@ from __future__ import annotations -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Sequence from contextlib import contextmanager +from selenium.webdriver.common.by import By from selenium.webdriver.remote.webdriver import WebDriver from reflex.testing import AppHarness @@ -33,6 +34,160 @@ def poll_for_navigation( AppHarness.expect(lambda: prev_url != driver.current_url, timeout=timeout) +def n_expected_events(exp_event_order: Sequence[str | set[str]]) -> int: + """Calculate the number of expected events, accounting for sets in the expected order. + + Args: + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + + Returns: + The total number of expected events. 
+ """ + return sum( + len(events) if isinstance(events, set) else 1 for events in exp_event_order + ) + + +def assert_event_order( + actual_event_order: list[str], exp_event_order: Sequence[str | set[str]] +) -> None: + """Verify that the actual event order matches the expected event order, accounting for sets in the expected order. + + Args: + actual_event_order: the actual events recorded in the State. + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + + Raises: + AssertionError: if the actual event order does not match the expected event order. + """ + actual_idx = 0 + for expected in exp_event_order: + if isinstance(expected, str): + assert actual_event_order[actual_idx] == expected, ( + f"Expected event '{expected}' at position {actual_idx}, but got '{actual_event_order[actual_idx]}'." + ) + actual_idx += 1 + else: # expected is a set of events that can occur in any order + expected_events = set(expected) + actual_events = set( + actual_event_order[actual_idx : actual_idx + len(expected_events)] + ) + assert actual_events == expected_events, ( + f"Expected events {expected_events} at positions {actual_idx} to {actual_idx + len(expected_events) - 1}, but got {actual_events}." + ) + actual_idx += len(expected_events) + assert actual_idx == len(actual_event_order), ( + f"Expected {actual_idx} events, but got {len(actual_event_order)}: {actual_event_order[actual_idx:]} remain." + ) + + +def poll_assert_event_order( + driver: WebDriver, + exp_event_order: Sequence[str | set[str]], + xpath: str = '//*[@id="event_order"]/p', +) -> None: + """Poll until the actual event order matches the expected event order, accounting for sets in the expected order. + + Args: + driver: WebDriver instance. + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + xpath: The XPath to the event order elements. 
+ + Raises: + AssertionError: if the actual event order does not match the expected event order after polling. + """ + n_exp_events = n_expected_events(exp_event_order) + + def _has_number_of_expected_events(): + event_elements = driver.find_elements(By.XPATH, xpath) + return len(event_elements) == n_exp_events + + AppHarness._poll_for(_has_number_of_expected_events) + + event_elements = driver.find_elements(By.XPATH, xpath) + assert_event_order([elem.text for elem in event_elements], exp_event_order) + + +# Type alias for an ordering rule: ((event_a, occurrence_a), (event_b, occurrence_b)). +OrderingRule = tuple[tuple[str, int], tuple[str, int]] + + +def assert_relative_event_order( + actual: list[str], + expected_counts: dict[str, int], + ordering_rules: list[OrderingRule], +) -> None: + """Assert that events satisfy relative ordering constraints. + + Instead of requiring an exact event sequence, this checks that: + 1. Each event appears the expected number of times. + 2. Specific occurrences of events appear before other specific occurrences. + + Args: + actual: the actual events recorded. + expected_counts: mapping of event name to expected occurrence count. + ordering_rules: list of ((event_a, occ_a), (event_b, occ_b)) meaning + the occ_a-th occurrence (0-indexed) of event_a must appear before + the occ_b-th occurrence (0-indexed) of event_b in the actual list. + + Raises: + AssertionError: if any constraint is violated. + """ + from collections import Counter + + actual_counts = Counter(actual) + for event, count in expected_counts.items(): + assert actual_counts[event] == count, ( + f"Expected {count} occurrences of '{event}', got {actual_counts[event]}. Actual: {actual}" + ) + assert sum(expected_counts.values()) == len(actual), ( + f"Expected {sum(expected_counts.values())} total events, got {len(actual)}. 
Actual: {actual}" + ) + + # Build occurrence index: (event, occ) -> position in actual list + occurrence_indices: dict[tuple[str, int], int] = {} + event_counters: dict[str, int] = {} + for i, event in enumerate(actual): + occ = event_counters.get(event, 0) + occurrence_indices[event, occ] = i + event_counters[event] = occ + 1 + + for (event_a, occ_a), (event_b, occ_b) in ordering_rules: + idx_a = occurrence_indices[event_a, occ_a] + idx_b = occurrence_indices[event_b, occ_b] + assert idx_a < idx_b, ( + f"Expected '{event_a}'[{occ_a}] (pos {idx_a}) before " + f"'{event_b}'[{occ_b}] (pos {idx_b}). Actual: {actual}" + ) + + +def poll_assert_relative_event_order( + driver: WebDriver, + expected_counts: dict[str, int], + ordering_rules: list[OrderingRule], + xpath: str = '//*[@id="event_order"]/p', +) -> None: + """Poll until the expected number of events appear, then assert relative ordering. + + Args: + driver: WebDriver instance. + expected_counts: mapping of event name to expected occurrence count. + ordering_rules: ordering constraints (see assert_relative_event_order). + xpath: The XPath to the event order elements. + """ + n_exp = sum(expected_counts.values()) + + def _has_number_of_expected_events(): + return len(driver.find_elements(By.XPATH, xpath)) == n_exp + + AppHarness._poll_for(_has_number_of_expected_events) + + event_elements = driver.find_elements(By.XPATH, xpath) + assert_relative_event_order( + [elem.text for elem in event_elements], expected_counts, ordering_rules + ) + + class LocalStorage: """Class to access local storage. 
diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 0e9a929b8fb..36baee0ec8e 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -1,19 +1,31 @@ """Test fixtures.""" import platform +import traceback import uuid -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator, Mapping +from copy import deepcopy +from typing import Any from unittest import mock import pytest +import pytest_asyncio from reflex_base.components.component import CUSTOM_COMPONENTS -from reflex_base.event import EventSpec +from reflex_base.event import Event, EventSpec +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor +from reflex_base.registry import RegistrationContext from reflex.app import App from reflex.experimental.memo import EXPERIMENTAL_MEMOS +from reflex.istate.manager import StateManager +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis from reflex.model import ModelRegistry from reflex.testing import chdir from reflex.utils import prerequisites +from tests.units.mock_redis import mock_redis from .states.upload import SubUploadState, UploadState @@ -213,6 +225,270 @@ def model_registry() -> Generator[type[ModelRegistry], None, None]: ModelRegistry._metadata = None +@pytest_asyncio.fixture( + loop_scope="function", scope="function", params=["in_process", "disk", "redis"] +) +async def state_manager( + request: pytest.FixtureRequest, mock_root_event_context: EventContext +) -> AsyncGenerator[StateManager, None]: + """Instance of state manager parametrized for redis and in-process. + + Args: + request: pytest request object. + mock_root_event_context: The mock root event context to use for the state manager. 
+ + Yields: + A state manager instance + """ + state_manager = StateManager.create() + if request.param == "redis": + if not isinstance(state_manager, StateManagerRedis): + state_manager = StateManagerRedis(redis=mock_redis()) + elif request.param == "disk": + # explicitly NOT using redis + state_manager = StateManagerDisk() + assert not state_manager._states_locks + else: + state_manager = StateManagerMemory() + assert not state_manager._states_locks + + orig_state_manager = mock_root_event_context.state_manager + object.__setattr__(mock_root_event_context, "state_manager", state_manager) + + yield state_manager + + await state_manager.close() + object.__setattr__(mock_root_event_context, "state_manager", orig_state_manager) + + +@pytest.fixture +def mock_event_processor_obj() -> EventProcessor: + """Create an event processor. + + Returns: + A fresh event processor. + """ + + def handle_backend_exception(ex: Exception) -> None: + raise ex + + return EventProcessor( + backend_exception_handler=handle_backend_exception, graceful_shutdown_timeout=1 + ) + + +@pytest.fixture +def mock_base_state_event_processor_obj( + monkeypatch: pytest.MonkeyPatch, +) -> BaseStateEventProcessor: + """Create a BaseState event processor. + + Args: + monkeypatch: pytest monkeypatch fixture. + + Returns: + A fresh BaseState event processor. + """ + monkeypatch.setattr(BaseStateEventProcessor, "_rehydrate", mock.AsyncMock()) + + def handle_backend_exception(ex: Exception) -> None: + formatted_exc = "\n".join(traceback.format_exception(ex)) + pytest.fail(f"Event processor raised an unexpected exception:\n{formatted_exc}") + + return BaseStateEventProcessor( + backend_exception_handler=handle_backend_exception, graceful_shutdown_timeout=1 + ) + + +@pytest.fixture +def emitted_deltas() -> list[tuple[str, Mapping[str, Mapping[str, Any]]]]: + """Create a list to store emitted deltas. + + Returns: + A list to store emitted deltas. 
+ """ + return [] + + +@pytest.fixture +def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: + """Create a list to store emitted events. + + Returns: + A list to store emitted events. + """ + return [] + + +@pytest_asyncio.fixture +async def mock_root_event_context( + mock_base_state_event_processor_obj: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], + emitted_events: list[tuple[str, tuple[Event, ...]]], +) -> AsyncGenerator[EventContext]: + """Create a mock event context. + + Args: + mock_base_state_event_processor_obj: The mock event processor to use for the context's enqueue implementation. + emitted_deltas: The list to store emitted deltas. + emitted_events: The list to store emitted events. + + Yields: + A mock event context. + """ + + async def emit_delta_impl( # noqa: RUF029 + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + """Mock emit delta implementation that records emitted deltas. + + Args: + token: The client token to emit the delta to. + delta: The delta to emit. + """ + emitted_deltas.append((token, delta)) + + async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 + """Mock emit event implementation that records emitted events. + + Args: + token: The client token to emit the events to. + events: The events to emit. + """ + emitted_events.append((token, events)) + + state_manager = StateManagerMemory() + yield EventContext( + token="", + state_manager=state_manager, + enqueue_impl=mock_base_state_event_processor_obj.enqueue_many, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + await state_manager.close() + + +@pytest.fixture +def mock_event_processor( + mock_root_event_context: EventContext, mock_event_processor_obj: EventProcessor +) -> EventProcessor: + """Create an event processor with a mock root context. + + Set the mock context as the task's current context, and set the processor's + root context to the mock context. 
+ + Events can be queued against the processor via `await + mock_event_processor.enqueue(token, *events)`. + + The `state_manager` fixture is used by the `mock_root_event_context` so any + updates will be reflected in the context's state manager, and any deltas or + frontend events can be checked via the context's `emitted_deltas` and + `emitted_events` attributes. + + Args: + mock_root_event_context: The mock event context to use as the root context for the processor. + mock_event_processor_obj: The mock event processor to use for the processor's enqueue implementation. + + Returns: + An un-started event processor with a mock root context. + """ + mock_event_processor_obj._root_context = mock_root_event_context + return mock_event_processor_obj + + +@pytest.fixture +def mock_base_state_event_processor( + mock_root_event_context: EventContext, + mock_base_state_event_processor_obj: BaseStateEventProcessor, +) -> BaseStateEventProcessor: + """Create a BaseState event processor with a mock root context. + + Set the mock context as the task's current context, and set the processor's + root context to the mock context. + + Events can be queued against the processor via `await + mock_base_state_event_processor.enqueue(token, *events)`. + + The `state_manager` fixture is used by the `mock_root_event_context` so any + updates will be reflected in the context's state manager, and any deltas or + frontend events can be checked via the context's `emitted_deltas` and + `emitted_events` attributes. + + Args: + mock_root_event_context: The mock event context to use as the root context for the processor. + mock_base_state_event_processor_obj: The mock BaseState event processor to use for the processor's enqueue implementation. + + Returns: + An un-started event processor with a mock root context. 
+ """ + mock_base_state_event_processor_obj._root_context = mock_root_event_context + return mock_base_state_event_processor_obj + + +@pytest.fixture +def attached_mock_event_context( + mock_root_event_context: EventContext, token: str +) -> Generator[EventContext, None, None]: + """Fork the mock event context for the given token and attach it. + + Sets the forked context as the current event_context for the duration + of the test, then resets it afterwards. + + Args: + mock_root_event_context: The mock root event context. + token: The client token. + + Yields: + The forked EventContext. + """ + with mock_root_event_context.fork(token=token) as ctx: + yield ctx + + +@pytest_asyncio.fixture +async def attached_mock_base_state_event_processor( + mock_base_state_event_processor: BaseStateEventProcessor, +) -> AsyncGenerator[BaseStateEventProcessor]: + """Fork the mock event context for the given token, attach it, and set the processor's root context to it. + + Args: + mock_base_state_event_processor: The mock BaseState event processor to use for the processor's enqueue implementation. + + Yields: + The mock BaseState event processor with the attached context as its root context. + """ + async with mock_base_state_event_processor as processor: + yield processor + + +@pytest.fixture +def forked_registration_context() -> Generator[RegistrationContext, None, None]: + """Fork the registration context and attach it. + + Sets the forked context as the current registration context for the duration + of the test, then resets it afterwards. + + Yields: + The forked RegistrationContext. + """ + with deepcopy(RegistrationContext.get()) as ctx: + yield ctx + + +@pytest.fixture +def clean_registration_context() -> Generator[RegistrationContext, None, None]: + """Create and attach a clean registration context. + + Sets the new context as the current registration context for the duration + of the test, then resets it afterwards. + + Yields: + The clean RegistrationContext. 
+ """ + with RegistrationContext() as ctx: + yield ctx + + @pytest.fixture def preserve_memo_registries(): """Save and restore global memo registries around a test. diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py index ff5c76de458..f20ea71b052 100644 --- a/tests/units/istate/manager/test_expiration.py +++ b/tests/units/istate/manager/test_expiration.py @@ -8,7 +8,8 @@ import pytest_asyncio from reflex.istate.manager.memory import StateManagerMemory -from reflex.state import BaseState, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState class ExpiringState(BaseState): @@ -45,7 +46,7 @@ async def state_manager_memory() -> AsyncGenerator[StateManagerMemory]: Yields: The memory state manager under test. """ - state_manager = StateManagerMemory(state=ExpiringState, token_expiration=1) + state_manager = StateManagerMemory(token_expiration=1) yield state_manager await state_manager.close() @@ -56,7 +57,7 @@ async def test_memory_state_manager_evicts_expired_state( token: str, ): """Expired states should be removed from the in-memory cache and locks.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) async with state_manager_memory.modify_state(state_token) as state: state.value = 42 @@ -80,7 +81,7 @@ async def test_memory_state_manager_get_state_refreshes_expiration( token: str, ): """Accessing a state should extend its expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 7 @@ -105,7 +106,7 @@ async def test_memory_state_manager_set_state_refreshes_expiration( token: str, ): """Persisting a state should extend its expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = 
BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 17 @@ -130,7 +131,7 @@ async def test_memory_state_manager_multiple_accesses_extend_expiration( token: str, ): """Repeated accesses should keep the state alive until it goes idle.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) expires_at = state_manager_memory._token_expires_at[token] @@ -154,7 +155,7 @@ async def test_memory_state_manager_returns_fresh_state_after_eviction( token: str, ): """A token should get a fresh state after the previous one expires.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 99 @@ -173,7 +174,7 @@ async def test_memory_state_manager_close_cancels_expiration_task( token: str, ): """Closing the manager should cancel the expiration task cleanly.""" - await state_manager_memory.get_state(_substate_key(token, ExpiringState)) + await state_manager_memory.get_state(BaseStateToken(ident=token, cls=ExpiringState)) expiration_task = state_manager_memory._expiration_task assert expiration_task is not None @@ -193,7 +194,7 @@ async def test_memory_state_manager_refreshes_expiration_after_locked_access( token: str, ): """Releasing a long-held state should start a fresh expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) async with state_manager_memory.modify_state(state_token) as state: state.value = 5 diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index d5fee452c5a..0b37389a337 100644 --- 
a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -11,7 +11,8 @@ import pytest_asyncio from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState from tests.units.mock_redis import mock_redis, real_redis @@ -51,7 +52,7 @@ async def state_manager_redis( async with real_redis() as redis: if redis is None: redis = mock_redis() - state_manager = StateManagerRedis(state=root_state, redis=redis) + state_manager = StateManagerRedis(redis=redis) test_start = time.monotonic() yield state_manager # None of the tests should have triggered a lock expiration. @@ -104,10 +105,14 @@ async def test_basic_get_set( token = str(uuid.uuid4()) - fresh_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + fresh_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) fresh_state.foo = "baz" fresh_state.count = 42 - await state_manager_redis.set_state(_substate_key(token, root_state), fresh_state) + await state_manager_redis.set_state( + BaseStateToken(ident=token, cls=root_state), fresh_state + ) async def test_modify( @@ -126,19 +131,21 @@ async def test_modify( # Initial modify should set count to 1 async with state_manager_redis.modify_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) as new_state: new_state.count = 1 # Subsequent modify should set count to 2 async with state_manager_redis.modify_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) as new_state: assert isinstance(new_state, root_state) assert new_state.count == 1 new_state.count += 2 - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert 
isinstance(final_state, root_state) assert final_state.count == 3 @@ -162,9 +169,7 @@ async def test_modify_oplock( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True @@ -173,7 +178,7 @@ async def test_modify_oplock( # Initial modify should set count to 1 async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 1 @@ -197,7 +202,7 @@ async def test_modify_oplock( # The second modify should NOT trigger another redis lock async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 2 assert state_lock_1.locked() @@ -213,7 +218,7 @@ async def test_modify_oplock( # Contend the lock from another state manager event_log_on_update.clear() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 3 state_lock_2 = state_manager_2._cached_states_locks.get(token) @@ -298,9 +303,7 @@ async def test_oplock_contention_queue( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True @@ -312,7 +315,7 @@ async def test_oplock_contention_queue( async def modify_1(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) 
new_state.count += 1 @@ -323,7 +326,7 @@ async def modify_2(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -333,7 +336,7 @@ async def modify_3(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -356,14 +359,16 @@ async def modify_3(): await task_3 interim_state = await state_manager_redis.get_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) assert isinstance(interim_state, root_state) assert interim_state.count == 1 await state_manager_2.close() - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 3 @@ -393,16 +398,12 @@ async def test_oplock_contention_no_lease( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True - state_manager_3 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_3 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_3._debug_enabled = True state_manager_3._oplock_enabled = True @@ -413,7 +414,7 @@ async def test_oplock_contention_no_lease( async def modify_1(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + 
BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -424,7 +425,7 @@ async def modify_2(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -434,7 +435,7 @@ async def modify_3(): await modify_started.wait() modify_2_started.set() async with state_manager_3.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -470,7 +471,9 @@ async def modify_3(): await state_manager_2.close() await state_manager_3.close() - final_state = await state_manager_2.get_state(_substate_key(token, root_state)) + final_state = await state_manager_2.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 3 @@ -502,9 +505,7 @@ async def test_oplock_contention_racers( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True lease_1 = None @@ -513,7 +514,7 @@ async def test_oplock_contention_racers( async def modify_1(): nonlocal lease_1 async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: lease_1 = await state_manager_redis._get_local_lease(token) assert isinstance(new_state, root_state) @@ -524,7 +525,7 @@ async def modify_2(): await asyncio.sleep(racer_delay) nonlocal lease_2 async with state_manager_2.modify_state( - _substate_key(token, root_state), + 
BaseStateToken(ident=token, cls=root_state), ) as new_state: lease_2 = await state_manager_2._get_local_lease(token) assert isinstance(new_state, root_state) @@ -573,7 +574,7 @@ async def canceller(): task = asyncio.create_task(canceller()) async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert await state_manager_redis._get_local_lease(token) is None assert isinstance(new_state, root_state) @@ -601,20 +602,20 @@ async def test_oplock_fetch_substate( state_manager_redis._oplock_enabled = True async with state_manager_redis.modify_state( - _substate_key(token, SubState1), + BaseStateToken(ident=token, cls=SubState1), ) as new_state: assert SubState1.get_name() in new_state.substates assert SubState2.get_name() not in new_state.substates async with state_manager_redis.modify_state( - _substate_key(token, SubState2), + BaseStateToken(ident=token, cls=SubState2), ) as new_state: # Both substates should be fetched and cached. assert SubState1.get_name() in new_state.substates assert SubState2.get_name() in new_state.substates async with state_manager_redis.modify_state( - _substate_key(token, SubState1), + BaseStateToken(ident=token, cls=SubState1), ) as new_state: # Both substates should be fetched and cached now. 
assert SubState1.get_name() in new_state.substates @@ -676,7 +677,7 @@ async def test_oplock_hold_oplock_after_cancel( async def modify(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: modify_started.set() assert isinstance(new_state, root_state) @@ -713,7 +714,7 @@ async def modify(): # Modify the state again, this should get a new lock and lease event_log_on_update.clear() async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -730,6 +731,8 @@ async def modify(): await state_manager_redis.close() # Both increments should be present. - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 2 diff --git a/tests/units/istate/manager/test_token.py b/tests/units/istate/manager/test_token.py new file mode 100644 index 00000000000..d6f9411b608 --- /dev/null +++ b/tests/units/istate/manager/test_token.py @@ -0,0 +1,184 @@ +"""Tests for StateToken, BaseStateToken, and from_legacy_token.""" + +import io +import pickle + +import pytest + +from reflex.istate.manager.token import BaseStateToken, StateToken + + +def test_state_token_str(): + """__str__ encodes ident and cls into 'ident/module.Class' format.""" + token = StateToken(ident="abc-123", cls=int) + assert str(token) == "abc-123/builtins.int" + + +def test_state_token_str_escapes_slashes(): + """Slashes in ident or cls name are percent-encoded.""" + token = StateToken(ident="a/b", cls=int) + result = str(token) + assert "%2F" in result + assert "/" in result + + +def test_state_token_with_cls(): + """with_cls returns a new token with updated cls, leaving the original 
unchanged.""" + token = StateToken(ident="tok", cls=int) + new = token.with_cls(bool) + assert new.cls is bool + assert new.ident == "tok" + assert token.cls is int + + +def test_state_token_serialize_deserialize_roundtrip(): + """serialize/deserialize with data= round-trips through pickle.""" + value = {"key": [1, 2, 3]} + data = StateToken.serialize(value) + assert isinstance(data, bytes) + assert StateToken.deserialize(data=data) == value + + +def test_state_token_deserialize_from_fp(): + """Deserialize with fp= reads from a file-like object.""" + value = "hello" + buf = io.BytesIO(pickle.dumps(value)) + assert StateToken.deserialize(fp=buf) == value + + +def test_state_token_deserialize_neither_raises(): + """Deserialize with neither data nor fp raises ValueError.""" + with pytest.raises(ValueError, match="At least one"): + StateToken.deserialize() + + +def test_state_token_deserialize_both_raises(): + """Deserialize with both data and fp raises ValueError.""" + with pytest.raises(ValueError, match="Only one"): + StateToken.deserialize(data=b"data", fp=io.BytesIO()) + + +def test_state_token_get_and_reset_touched_state(): + """Default implementation always returns True.""" + assert StateToken.get_and_reset_touched_state("anything") is True + + +def test_base_state_token_str(clean_registration_context): + """__str__ uses 'ident_full_name' format. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class TokState(BaseState): + pass + + token = BaseStateToken(ident="client-abc", cls=TokState) + result = str(token) + assert result.startswith("client-abc_") + assert TokState.get_full_name() in result + + +def test_base_state_token_with_cls(clean_registration_context): + """with_cls returns a BaseStateToken (not a plain StateToken). + + Args: + clean_registration_context: A fresh, empty registration context. 
+ """ + from reflex.state import BaseState + + class A(BaseState): + pass + + class B(BaseState): + pass + + token = BaseStateToken(ident="tok", cls=A) + new = token.with_cls(B) + assert isinstance(new, BaseStateToken) + assert new.cls is B + + +def test_base_state_token_serialize_deserialize(clean_registration_context): + """BaseStateToken serialization uses BaseState._serialize/_deserialize. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class SerState(BaseState): + x: int = 42 + + state = SerState() + data = BaseStateToken.serialize(state) + assert isinstance(data, bytes) + restored = BaseStateToken.deserialize(data=data) + assert isinstance(restored, SerState) + assert restored.x == 42 + + +def test_base_state_token_get_and_reset_touched(clean_registration_context): + """get_and_reset_touched_state returns the touched flag and resets it. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class TouchState(BaseState): + x: int = 0 + + state = TouchState() + state._was_touched = True + assert BaseStateToken.get_and_reset_touched_state(state) is True + assert state._was_touched is False + assert BaseStateToken.get_and_reset_touched_state(state) is False + + +def test_from_legacy_token(clean_registration_context): + """from_legacy_token parses 'ident_state.path' into a BaseStateToken. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class LegacyRoot(BaseState): + pass + + full_name = LegacyRoot.get_full_name() + legacy_str = f"my-client-token_{full_name}" + + token = BaseStateToken.from_legacy_token(legacy_str, root_state=LegacyRoot) + assert token.ident == "my-client-token" + assert token.cls is LegacyRoot + + +def test_from_legacy_token_substate(clean_registration_context): + """from_legacy_token resolves a substate path. 
+ + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class LegRoot(BaseState): + pass + + class LegChild(LegRoot): + pass + + full_name = LegChild.get_full_name() + legacy_str = f"tok_{full_name}" + + token = BaseStateToken.from_legacy_token(legacy_str, root_state=LegRoot) + assert token.ident == "tok" + assert token.cls is LegChild + + +def test_from_legacy_token_none_root_raises(): + """from_legacy_token with root_state=None raises ValueError.""" + with pytest.raises(ValueError, match="Root state must be provided"): + BaseStateToken.from_legacy_token("tok_some.state", root_state=None) diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index 5fd29725fa9..ec9a923b52a 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -4,9 +4,9 @@ import pickle from asyncio import CancelledError from contextlib import asynccontextmanager -from unittest.mock import patch import pytest +from reflex_base.event.context import EventContext import reflex as rx from reflex.istate.proxy import MutableProxy, StateProxy @@ -42,14 +42,15 @@ def test_mutable_proxy_pickle_preserves_object_identity(): assert unpickled["direct"][0] is unpickled["proxied"][0] -@pytest.mark.usefixtures("mock_app") @pytest.mark.asyncio -async def test_state_proxy_recovery(): +async def test_state_proxy_recovery( + attached_mock_event_context: EventContext, monkeypatch: pytest.MonkeyPatch +): """Ensure that `async with self` can be re-entered after a lock issue.""" state = ProxyTestState() state_proxy = StateProxy(state) - with patch("reflex.app.App.modify_state") as mock_modify_state: + with monkeypatch.context() as m: @asynccontextmanager async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029 @@ -57,7 +58,11 @@ async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029 raise CancelledError(msg) yield - mock_modify_state.side_effect = 
mock_modify_state_context + m.setattr( + attached_mock_event_context.state_manager, + "modify_state", + mock_modify_state_context, + ) with pytest.raises(CancelledError, match="Simulated lock issue"): async with state_proxy: diff --git a/tests/units/middleware/conftest.py b/tests/units/middleware/conftest.py index 4b531ecff92..d0895608a38 100644 --- a/tests/units/middleware/conftest.py +++ b/tests/units/middleware/conftest.py @@ -6,7 +6,6 @@ def create_event(name): return Event( - token="", name=name, router_data={ "pathname": "/", diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 2ac1fadd139..ed437e7e4e0 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -1,7 +1,7 @@ from __future__ import annotations import pytest -from pytest_mock import MockerFixture +from reflex_base.registry import RegistrationContext from reflex.app import App from reflex.middleware.hydrate_middleware import HydrateMiddleware @@ -31,15 +31,17 @@ def hydrate_middleware() -> HydrateMiddleware: @pytest.mark.asyncio -async def test_preprocess_no_events(hydrate_middleware, event1, mocker: MockerFixture): +async def test_preprocess_no_events( + hydrate_middleware, event1, clean_registration_context: RegistrationContext +): """Test that app without on_load is processed correctly. Args: hydrate_middleware: Instance of HydrateMiddleware event1: An Event. - mocker: pytest mock object. + clean_registration_context: The registration context fixture, which is cleared before each test. 
""" - mocker.patch("reflex.state.State.class_subclasses", {TestState}) + clean_registration_context.register_base_state(TestState) state = State() update = await hydrate_middleware.preprocess( app=App(_state=State), diff --git a/tests/units/reflex_base/__init__.py b/tests/units/reflex_base/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_base/context/__init__.py b/tests/units/reflex_base/context/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_base/context/test_base.py b/tests/units/reflex_base/context/test_base.py new file mode 100644 index 00000000000..11db7963159 --- /dev/null +++ b/tests/units/reflex_base/context/test_base.py @@ -0,0 +1,94 @@ +"""Tests for BaseContext.""" + +import dataclasses + +import pytest +from reflex_base.context.base import BaseContext + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _TestContext(BaseContext): + """Minimal BaseContext subclass for unit testing.""" + + label: str = "test" + + +def test_get_without_set_raises(): + """get() raises LookupError when no context is set.""" + with pytest.raises(LookupError): + _TestContext.get() + + +def test_set_and_get(): + """set() makes the context retrievable via get().""" + ctx = _TestContext(label="a") + token = _TestContext.set(ctx) + try: + assert _TestContext.get() is ctx + finally: + _TestContext.reset(token) + + +def test_reset_restores_previous(): + """reset() restores the previously active context.""" + outer = _TestContext(label="outer") + outer_tok = _TestContext.set(outer) + try: + inner = _TestContext(label="inner") + inner_tok = _TestContext.set(inner) + assert _TestContext.get() is inner + _TestContext.reset(inner_tok) + assert _TestContext.get() is outer + finally: + _TestContext.reset(outer_tok) + + +def test_context_manager_enter_exit(): + """__enter__ sets the context and __exit__ removes it.""" + ctx = _TestContext(label="cm") + with ctx as entered: + 
assert entered is ctx + assert _TestContext.get() is ctx + with pytest.raises(LookupError): + _TestContext.get() + + +def test_context_manager_nesting(): + """Nested context managers restore the outer context on inner exit.""" + outer = _TestContext(label="outer") + inner = _TestContext(label="inner") + with outer: + assert _TestContext.get().label == "outer" + with inner: + assert _TestContext.get().label == "inner" + assert _TestContext.get().label == "outer" + + +def test_double_enter_raises(): + """Entering the same context instance twice raises RuntimeError.""" + ctx = _TestContext(label="double") + with ctx, pytest.raises(RuntimeError, match="already attached"): + ctx.__enter__() + + +def test_ensure_context_attached(): + """ensure_context_attached raises when not entered, succeeds when entered.""" + ctx = _TestContext(label="ensure") + with pytest.raises(RuntimeError, match="must be entered"): + ctx.ensure_context_attached() + with ctx: + ctx.ensure_context_attached() + + +def test_subclasses_have_independent_context_vars(): + """Two BaseContext subclasses do not share their ContextVar.""" + + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) + class _OtherContext(BaseContext): + value: int = 0 + + ctx_a = _TestContext(label="a") + ctx_b = _OtherContext(value=42) + with ctx_a, ctx_b: + assert _TestContext.get().label == "a" + assert _OtherContext.get().value == 42 diff --git a/tests/units/reflex_base/event/__init__.py b/tests/units/reflex_base/event/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_base/event/processor/test_base_state_processor.py b/tests/units/reflex_base/event/processor/test_base_state_processor.py new file mode 100644 index 00000000000..e414369a40e --- /dev/null +++ b/tests/units/reflex_base/event/processor/test_base_state_processor.py @@ -0,0 +1,152 @@ +"""Tests for BaseStateEventProcessor, specifically the _rehydrate path.""" + +import traceback +from collections.abc import 
Mapping +from typing import Any + +import pytest +import pytest_asyncio +from reflex_base.constants import CompileVars +from reflex_base.constants.state import FIELD_MARKER +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor +from reflex_base.registry import RegistrationContext + +from reflex import event +from reflex.app import App +from reflex.event import Event +from reflex.istate.manager.memory import StateManagerMemory +from reflex.state import OnLoadInternalState, State + + +@pytest.fixture +def _real_base_state_processor_obj() -> BaseStateEventProcessor: + """A BaseStateEventProcessor with real (unmocked) _rehydrate. + + Returns: + A fresh BaseStateEventProcessor instance. + """ + + def handle_backend_exception(ex: Exception) -> None: + formatted_exc = "\n".join(traceback.format_exception(ex)) + pytest.fail(f"Event processor raised an unexpected exception:\n{formatted_exc}") + + return BaseStateEventProcessor( + backend_exception_handler=handle_backend_exception, + graceful_shutdown_timeout=2, + ) + + +@pytest.fixture +def emitted_deltas() -> list[tuple[str, Mapping[str, Mapping[str, Any]]]]: + """List to capture emitted deltas. + + Returns: + An empty list for collecting deltas. + """ + return [] + + +@pytest.fixture +def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: + """List to capture emitted events. + + Returns: + An empty list for collecting events. + """ + return [] + + +@pytest_asyncio.fixture +async def real_base_state_processor( + _real_base_state_processor_obj: BaseStateEventProcessor, + emitted_deltas: list, + emitted_events: list, + clean_registration_context: RegistrationContext, +): + """A fully wired BaseStateEventProcessor with real _rehydrate. + + Yields the processor (not yet started). The test must use ``async with processor`` to + control the lifecycle and assert on emitted deltas after stop. 
+ + Args: + _real_base_state_processor_obj: The unmocked processor instance. + emitted_deltas: List to capture emitted deltas. + emitted_events: List to capture emitted events. + clean_registration_context: Isolated registration context for the test. + + Yields: + The configured but not-yet-started BaseStateEventProcessor. + """ + clean_registration_context.register_base_state(OnLoadInternalState) + state_manager = StateManagerMemory() + + async def emit_delta_impl( # noqa: RUF029 + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + emitted_deltas.append((token, delta)) + + async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 + emitted_events.append((token, events)) + + root_ctx = EventContext( + token="", + state_manager=state_manager, + enqueue_impl=_real_base_state_processor_obj.enqueue_many, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + _real_base_state_processor_obj._root_context = root_ctx + + yield _real_base_state_processor_obj + + await state_manager.close() + + +async def test_rehydrate_sets_is_hydrated_on_fresh_token( + app_module_mock, + real_base_state_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], + token: str, +): + """A non-hydrate event against a fresh token triggers _rehydrate, emitting is_hydrated=True. + + When a token has never been seen before (no router_data on the state), + and the event is not itself the hydrate event, the processor calls + _rehydrate which runs State.hydrate. With no on_load events defined, + hydrate sets is_hydrated=True directly. + + Args: + app_module_mock: The mock app module fixture. + real_base_state_processor: The unmocked BaseStateEventProcessor. + emitted_deltas: List to capture emitted deltas. + token: The client token. 
+ """ + + class MyState(State): + @event + def noop(self): + pass + + OnLoadInternalState._app_ref = None + app = app_module_mock.app = App() + assert real_base_state_processor._root_context is not None + app._state_manager = real_base_state_processor._root_context.state_manager + + async with real_base_state_processor as processor: + await processor.enqueue( + token, + Event.from_event_type(MyState.noop())[0], + ) + await processor.join(1) + + state_name = State.get_full_name() + is_hydrated_key = CompileVars.IS_HYDRATED + FIELD_MARKER + hydrated_deltas = [ + d + for _, d in emitted_deltas + if state_name in d and d[state_name].get(is_hydrated_key) is True + ] + assert len(hydrated_deltas) >= 1, ( + f"Expected at least one delta with is_hydrated=True, got deltas: {emitted_deltas}" + ) diff --git a/tests/units/reflex_base/event/processor/test_event_processor.py b/tests/units/reflex_base/event/processor/test_event_processor.py new file mode 100644 index 00000000000..de1ea4dcb23 --- /dev/null +++ b/tests/units/reflex_base/event/processor/test_event_processor.py @@ -0,0 +1,592 @@ +"""Tests for EventProcessor lifecycle, task management, and error handling.""" + +import asyncio +import contextlib +from typing import Any + +import pytest +from reflex_base.event.context import EventContext +from reflex_base.event.processor.event_processor import EventProcessor, QueueShutDown +from reflex_base.registry import RegistrationContext + +from reflex.event import Event, EventHandler + +# Module-level log so event handlers can record what happened. +_CALL_LOG: list[dict[str, Any]] = [] + + +async def _noop_handler(): + """A handler that does nothing.""" + + +async def _slow_handler(delay: float = 0.5): + """A handler that sleeps for *delay* seconds. + + Args: + delay: How long to sleep in seconds. 
+ """ + await asyncio.sleep(delay) + + +async def _error_handler(): # noqa: RUF029 + """A handler that always raises.""" + raise RuntimeError("boom") # noqa: EM101 + + +async def _logging_handler(value: str = "default"): # noqa: RUF029 + """A handler that records its invocation. + + Args: + value: The value to log. + """ + _CALL_LOG.append({"value": value}) + + +async def _chaining_handler(): + """A handler that enqueues a logging event via the current EventContext.""" + ctx = EventContext.get() + await ctx.enqueue( + Event.from_event_type(logging_event("chained"))[0], + ) + + +async def _delta_handler(): + """A handler that emits a single delta.""" + ctx = EventContext.get() + await ctx.emit_delta({"state": {"x": 1}}) + + +async def _multi_delta_handler(): + """A handler that emits multiple deltas with a small pause between them.""" + ctx = EventContext.get() + for i in range(3): + await ctx.emit_delta({"state": {"i": i}}) + await asyncio.sleep(0.01) + + +async def _slow_logging_handler(value: str = "default"): + """A slow logging handler that pauses before recording. + + Args: + value: The value to log. 
+ """ + await asyncio.sleep(0.05) + _CALL_LOG.append({"value": value}) + + +async def _multi_chaining_handler(): + """A handler that enqueues three slow logging events in sequence.""" + ctx = EventContext.get() + for label in ("first", "second", "third"): + await ctx.enqueue( + Event.from_event_type(slow_logging_event(label))[0], + ) + + +async def _background_then_normal_handler(): + """A handler that enqueues a background event followed by a normal slow event.""" + ctx = EventContext.get() + await ctx.enqueue(Event.from_event_type(background_slow_logging_event("bg"))[0]) + await ctx.enqueue(Event.from_event_type(slow_logging_event("normal"))[0]) + + +async def _error_then_logging_handler(): + """A handler that enqueues an error event followed by a logging event.""" + ctx = EventContext.get() + await ctx.enqueue(Event.from_event_type(error_event())[0]) + await ctx.enqueue(Event.from_event_type(logging_event("after_chain_error"))[0]) + + +async def _background_slow_logging_handler(value: str = "default"): + """A background version of the slow logging handler. + + Args: + value: The value to log. 
+ """ + await asyncio.sleep(0.05) + _CALL_LOG.append({"value": value}) + + +_background_slow_logging_handler._reflex_background_task = True # type: ignore[attr-defined] + + +noop_event = EventHandler(fn=_noop_handler) +slow_event = EventHandler(fn=_slow_handler) +error_event = EventHandler(fn=_error_handler) +logging_event = EventHandler(fn=_logging_handler) +chaining_event = EventHandler(fn=_chaining_handler) +delta_event = EventHandler(fn=_delta_handler) +multi_delta_event = EventHandler(fn=_multi_delta_handler) +slow_logging_event = EventHandler(fn=_slow_logging_handler) +multi_chaining_event = EventHandler(fn=_multi_chaining_handler) +background_slow_logging_event = EventHandler(fn=_background_slow_logging_handler) +background_then_normal_event = EventHandler(fn=_background_then_normal_handler) +error_then_logging_event = EventHandler(fn=_error_then_logging_handler) + + +@pytest.fixture(autouse=True) +def _register_handlers(forked_registration_context: RegistrationContext): + """Register all test event handlers and clear the call log. + + Args: + forked_registration_context: Isolated registration context for the test. + """ + _CALL_LOG.clear() + for handler in ( + noop_event, + slow_event, + error_event, + logging_event, + chaining_event, + delta_event, + multi_delta_event, + slow_logging_event, + multi_chaining_event, + background_slow_logging_event, + background_then_normal_event, + error_then_logging_event, + ): + RegistrationContext.register_event_handler(handler) + + +@pytest.fixture +def processor() -> EventProcessor: + """A bare EventProcessor with no backend_exception_handler. + + Returns: + A fresh EventProcessor instance. + """ + return EventProcessor(graceful_shutdown_timeout=2) + + +def test_configure_once(processor: EventProcessor): + """Calling configure() twice raises RuntimeError. + + Args: + processor: The event processor fixture. 
+ """ + processor.configure() + with pytest.raises(RuntimeError, match="already configured"): + processor.configure() + + +async def test_start_before_configure(processor: EventProcessor): + """Starting before configure raises RuntimeError. + + Args: + processor: The event processor fixture. + """ + with pytest.raises(RuntimeError, match="not configured"): + await processor.start() + + +async def test_start_twice(processor: EventProcessor): + """Starting a second time raises RuntimeError. + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.start() + try: + with pytest.raises(RuntimeError, match="already started"): + await processor.start() + finally: + await processor.stop() + + +async def test_stop_idempotent(processor: EventProcessor): + """Stopping an already-stopped processor does not error. + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.start() + await processor.stop() + await processor.stop() + + +async def test_async_context_manager(processor: EventProcessor): + """Entering/exiting via ``async with`` starts and stops the processor. + + Args: + processor: The event processor fixture. + """ + processor.configure() + async with processor as ep: + assert ep._queue is not None + assert ep._queue is None + assert ep._queue_task is None + + +async def test_enqueue_after_stop_raises(processor: EventProcessor): + """Enqueueing after stop raises because the queue is gone. + + Args: + processor: The event processor fixture. + """ + processor.configure() + async with processor: + pass + with pytest.raises(QueueShutDown, match="not running"): + await processor.enqueue("tok", Event.from_event_type(noop_event())[0]) + + +async def test_enqueue_before_start_raises(processor: EventProcessor): + """Enqueueing before start raises because the queue doesn't exist. + + Args: + processor: The event processor fixture. 
+ """ + processor.configure() + with pytest.raises(QueueShutDown, match="not running"): + await processor.enqueue("tok", Event.from_event_type(noop_event())[0]) + + +async def test_events_are_processed( + mock_event_processor: EventProcessor, + emitted_deltas: list, + token: str, +): + """Events enqueued are actually processed. + + Args: + mock_event_processor: The event processor with mock root context. + emitted_deltas: List to capture emitted deltas. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, Event.from_event_type(logging_event("hello"))[0]) + assert _CALL_LOG == [{"value": "hello"}] + + +async def test_enqueue_returns_future( + mock_event_processor: EventProcessor, + token: str, +): + """enqueue() returns a Future that resolves when the task finishes. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, Event.from_event_type(noop_event())[0]) + assert isinstance(future, asyncio.Future) + assert future.done() + + +async def test_tasks_cleared_after_stop( + mock_event_processor: EventProcessor, + token: str, +): + """After stop(), the internal _tasks dict is empty. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) + assert ep._tasks == {} + + +async def test_futures_cleared_after_stop( + mock_event_processor: EventProcessor, + token: str, +): + """After stop(), the internal _futures dict is empty. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. 
+ """ + async with mock_event_processor as ep: + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) + assert ep._futures == {} + + +async def test_slow_tasks_cancelled_on_stop(processor: EventProcessor): + """Tasks that haven't finished by the graceful timeout are cancelled. + + Args: + processor: The event processor fixture. + """ + processor.graceful_shutdown_timeout = 0 + processor.configure() + async with processor as ep: + future = await ep.enqueue("tok", Event.from_event_type(slow_event(10.0))[0]) + assert future.cancelled() + assert ep._tasks == {} + + +async def test_multiple_futures_cancelled_on_stop(processor: EventProcessor): + """Unresolved futures are cancelled during stop. + + Args: + processor: The event processor fixture. + """ + processor.graceful_shutdown_timeout = 0 + processor.configure() + async with processor as ep: + f1 = await ep.enqueue("t1", Event.from_event_type(slow_event(10.0))[0]) + f2 = await ep.enqueue("t2", Event.from_event_type(slow_event(10.0))[0]) + for f in (f1, f2): + assert f.done() + assert ep._futures == {} + + +async def test_cancel_future_before_task_starts( + mock_event_processor: EventProcessor, + token: str, +): + """Cancelling the future before the task starts skips processing. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, Event.from_event_type(slow_event(10.0))[0]) + future.cancel() + await asyncio.sleep(0.05) + assert ep._tasks == {} + + +async def test_cancel_future_cancels_running_task( + mock_event_processor: EventProcessor, + token: str, +): + """Cancelling the future cancels an already-running task. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. 
+ """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, Event.from_event_type(slow_event(10.0))[0]) + await asyncio.sleep(0.05) + future.cancel() + await asyncio.sleep(0.05) + assert ep._tasks == {} + + +async def test_exception_propagated_to_future( + processor: EventProcessor, + token: str, +): + """An exception in the handler is set on the future. + + Args: + processor: The event processor fixture. + token: The client token. + """ + processor.configure() + async with processor as ep: + future = await ep.enqueue(token, Event.from_event_type(error_event())[0]) + assert future.done() + with pytest.raises(RuntimeError, match="boom"): + future.result() + + +async def test_backend_exception_handler_called(token: str): + """The backend_exception_handler receives the exception. + + Args: + token: The client token. + """ + caught: list[Exception] = [] + + def _catch(ex: Exception) -> None: + caught.append(ex) + + ep = EventProcessor(backend_exception_handler=_catch, graceful_shutdown_timeout=2) + ep.configure() + async with ep: + await ep.enqueue(token, Event.from_event_type(error_event())[0]) + assert len(caught) == 1 + assert isinstance(caught[0], RuntimeError) + + +async def test_error_does_not_stop_queue( + processor: EventProcessor, + token: str, +): + """A failing event does not prevent subsequent events from processing. + + Args: + processor: The event processor fixture. + token: The client token. + """ + processor.configure() + async with processor as ep: + await ep.enqueue(token, Event.from_event_type(error_event())[0]) + await ep.enqueue(token, Event.from_event_type(logging_event("after_error"))[0]) + assert _CALL_LOG == [{"value": "after_error"}] + + +async def test_chained_event_processed(token: str): + """An event handler that enqueues another event via ctx.enqueue succeeds. + + Args: + token: The client token. 
+ """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + await ep.enqueue(token, Event.from_event_type(chaining_event())[0]) + assert _CALL_LOG == [{"value": "chained"}] + + +async def test_join_when_not_started(processor: EventProcessor): + """join() when not started is a no-op (queue is None). + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.join(timeout=1) + + +async def test_join_completes_after_processing( + mock_event_processor: EventProcessor, + token: str, +): + """join() returns once all queued entries have been dequeued. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) + await ep.join(timeout=5) + + +async def test_stream_delta_yields_single_delta(token: str): + """enqueue_stream_delta yields a delta emitted by the handler. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(delta_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [{"state": {"x": 1}}] + + +async def test_stream_delta_yields_multiple_deltas(token: str): + """enqueue_stream_delta yields all deltas in order. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(multi_delta_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [ + {"state": {"i": 0}}, + {"state": {"i": 1}}, + {"state": {"i": 2}}, + ] + + +async def test_stream_delta_noop_handler_yields_nothing(token: str): + """enqueue_stream_delta with a handler that emits no deltas yields nothing. + + Args: + token: The client token. 
+ """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(noop_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [] + + +async def test_stream_delta_not_configured_raises(): + """enqueue_stream_delta raises RuntimeError if processor is not configured.""" + ep = EventProcessor() + with pytest.raises(RuntimeError, match="not configured"): + async for _ in ep.enqueue_stream_delta("tok", Event(name="x", payload={})): + pass + + +async def test_sequential_chained_events_run_in_order(token: str): + """Chained events enqueued by a handler run in the order they were enqueued. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(multi_chaining_event())[0] + ) + await future.wait_all() + assert [entry["value"] for entry in _CALL_LOG] == ["first", "second", "third"] + + +async def test_futures_cleaned_up_after_chained_events(token: str): + """All futures are removed from _futures after chained events complete. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(multi_chaining_event())[0] + ) + await future.wait_all() + assert ep._futures == {} + + +async def test_background_event_does_not_block_sequential_sibling(token: str): + """A background event enqueued before a sequential sibling does not delay it. + + The background event (sequential=False) should execute concurrently while + the normal sibling is free to start without waiting for the background + event to finish first. + + Args: + token: The client token. 
+ """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(background_then_normal_event())[0] + ) + await future.wait_all() + # Both events should have been processed regardless of order. + values = {entry["value"] for entry in _CALL_LOG} + assert values == {"bg", "normal"} + + +async def test_sequential_chain_continues_after_error(token: str): + """A sequential chained event still runs when the preceding sibling raised an exception. + + The error in the first chained event must not block the second chained + event from executing. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(error_then_logging_event())[0] + ) + with contextlib.suppress(Exception): + await future.wait_all() + assert _CALL_LOG == [{"value": "after_chain_error"}] diff --git a/tests/units/reflex_base/event/processor/test_future.py b/tests/units/reflex_base/event/processor/test_future.py new file mode 100644 index 00000000000..dec13eb7cd8 --- /dev/null +++ b/tests/units/reflex_base/event/processor/test_future.py @@ -0,0 +1,301 @@ +"""Tests for EventFuture.""" + +import asyncio + +import pytest +from reflex_base.event.processor.future import EventFuture + + +@pytest.mark.asyncio +async def test_create_uses_running_loop(): # noqa: RUF029 + """EventFuture() defaults to the running event loop.""" + running_loop = asyncio.get_running_loop() + f = EventFuture(txid="f") + assert isinstance(f, EventFuture) + assert f.get_loop() is running_loop + assert f.children == [] + assert not f.done() + + +@pytest.mark.asyncio +async def test_create_with_explicit_loop(): # noqa: RUF029 + """EventFuture(loop=...) 
uses the given (non-default) loop.""" + other_loop = asyncio.new_event_loop() + try: + f = EventFuture(txid="f", loop=other_loop) + assert isinstance(f, EventFuture) + assert f.get_loop() is other_loop + assert f.get_loop() is not asyncio.get_running_loop() + finally: + other_loop.close() + + +@pytest.mark.asyncio +async def test_add_child_multiple(): # noqa: RUF029 + """add_child can be called multiple times.""" + parent = EventFuture(txid="parent") + children = [EventFuture(txid=f"c{i}") for i in range(3)] + for c in children: + parent.add_child(c) + assert parent.children == children + + +@pytest.mark.asyncio +async def test_add_child_to_done_future_raises(): # noqa: RUF029 + """add_child raises RuntimeError if the parent future is already done.""" + parent = EventFuture(txid="parent") + parent.set_result(None) + child = EventFuture(txid="child") + with pytest.raises(RuntimeError, match="already done"): + parent.add_child(child) + + +@pytest.mark.asyncio +async def test_add_child_to_cancelled_future_raises(): # noqa: RUF029 + """add_child raises RuntimeError if the parent future is cancelled.""" + parent = EventFuture(txid="parent") + parent.cancel() + child = EventFuture(txid="child") + with pytest.raises(RuntimeError, match="already done"): + parent.add_child(child) + + +@pytest.mark.asyncio +async def test_all_done_no_children(): # noqa: RUF029 + """all_done is True when the future is resolved and has no children.""" + f = EventFuture(txid="f") + assert not f.all_done() + f.set_result(42) + assert f.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_pending_child(): # noqa: RUF029 + """all_done is False when a child is still pending.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + parent.set_result(None) + assert not parent.all_done() + child.set_result(None) + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_all_done_nested(): # noqa: RUF029 + """all_done checks the 
full descendant tree.""" + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = EventFuture(txid="grandchild") + root.add_child(child) + child.add_child(grandchild) + + root.set_result(None) + child.set_result(None) + # grandchild still pending + assert not root.all_done() + + grandchild.set_result(None) + assert root.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_cancelled_child(): # noqa: RUF029 + """all_done is True when all children are cancelled (done).""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + parent.set_result(None) + child.cancel() + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_exception_child(): # noqa: RUF029 + """all_done is True when a child has an exception (still done).""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + parent.set_result(None) + child.set_exception(ValueError("boom")) + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_wait_all_returns_result(): + """wait_all returns the result of the root future.""" + f = EventFuture(txid="f") + f.set_result(42) + result = await f.wait_all() + assert result == 42 + + +@pytest.mark.asyncio +async def test_wait_all_waits_for_children(): + """wait_all waits for all children to complete.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + + async def resolve_later(): + await asyncio.sleep(0.01) + child.set_result("child_done") + + parent.set_result("parent_done") + task = asyncio.create_task(resolve_later()) + result = await parent.wait_all() + assert result == "parent_done" + assert child.done() + await task + + +@pytest.mark.asyncio +async def test_wait_all_waits_for_nested_children(): + """wait_all waits for grandchildren too.""" + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = 
EventFuture(txid="grandchild") + root.add_child(child) + child.add_child(grandchild) + + async def resolve_chain(): + await asyncio.sleep(0.01) + child.set_result(None) + await asyncio.sleep(0.01) + grandchild.set_result(None) + + root.set_result("root") + task = asyncio.create_task(resolve_chain()) + result = await root.wait_all() + assert result == "root" + assert grandchild.done() + await task + + +@pytest.mark.asyncio +async def test_wait_all_suppresses_child_exceptions(): + """wait_all suppresses exceptions from children.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + + parent.set_result("ok") + child.set_exception(ValueError("child error")) + + # Should not raise + result = await parent.wait_all() + assert result == "ok" + + +@pytest.mark.asyncio +async def test_wait_all_suppresses_child_cancellation(): + """wait_all suppresses CancelledError from children.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + + parent.set_result("ok") + child.cancel() + + result = await parent.wait_all() + assert result == "ok" + + +@pytest.mark.asyncio +async def test_wait_all_children_added_during_iteration(): + """wait_all picks up children added while iterating (index-based walk).""" + parent = EventFuture(txid="parent") + child1 = EventFuture(txid="child1") + parent.add_child(child1) + parent.set_result("done") + + # child2 will be added to child1 after child1 resolves, + # simulating a chained event that enqueues more events. 
+ child2 = EventFuture(txid="child2") + + async def resolve_and_chain(): + await asyncio.sleep(0.01) + child1.add_child(child2) + child1.set_result(None) + await asyncio.sleep(0.01) + child2.set_result(None) + + task = asyncio.create_task(resolve_and_chain()) + await parent.wait_all() + assert child2.done() + await task + + +@pytest.mark.asyncio +async def test_cancel_no_children(): # noqa: RUF029 + """Cancel cancels the future itself.""" + f = EventFuture(txid="f") + assert f.cancel() + assert f.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_cascades_to_children(): # noqa: RUF029 + """Cancel propagates to all children.""" + parent = EventFuture(txid="parent") + child1 = EventFuture(txid="child1") + child2 = EventFuture(txid="child2") + parent.add_child(child1) + parent.add_child(child2) + + parent.cancel() + assert parent.cancelled() + assert child1.cancelled() + assert child2.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_cascades_to_grandchildren(): # noqa: RUF029 + """Cancel propagates through the full descendant tree.""" + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = EventFuture(txid="grandchild") + root.add_child(child) + child.add_child(grandchild) + + root.cancel() + assert grandchild.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_with_message(): # noqa: RUF029 + """Cancel passes the message to children.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") + parent.add_child(child) + + parent.cancel("shutting down") + assert parent.cancelled() + assert child.cancelled() + with pytest.raises(asyncio.CancelledError, match="shutting down"): + parent.result() + with pytest.raises(asyncio.CancelledError, match="shutting down"): + child.result() + + +@pytest.mark.asyncio +async def test_cancel_already_done_child(): # noqa: RUF029 + """Cancel on a parent does not fail if a child is already resolved.""" + parent = EventFuture(txid="parent") + child = 
EventFuture(txid="child") + parent.add_child(child) + child.set_result("already done") + + parent.cancel() + assert parent.cancelled() + # child was already done, cancel returns False but doesn't raise + assert not child.cancelled() + assert child.result() == "already done" + + +@pytest.mark.asyncio +async def test_cancel_already_done_parent_returns_false(): # noqa: RUF029 + """Cancel returns False if the parent is already resolved.""" + f = EventFuture(txid="f") + f.set_result(None) + assert not f.cancel() diff --git a/tests/units/reflex_base/event/processor/test_timeout.py b/tests/units/reflex_base/event/processor/test_timeout.py new file mode 100644 index 00000000000..54f0827ea4a --- /dev/null +++ b/tests/units/reflex_base/event/processor/test_timeout.py @@ -0,0 +1,29 @@ +"""Tests for DrainTimeoutManager.""" + +import time + +from reflex_base.event.processor.timeout import DrainTimeoutManager + + +def test_drain_timeout_no_timeout(): + """DrainTimeoutManager with no timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(None) + with dtm as remaining: + assert remaining == 0 + + +def test_drain_timeout_decreases(): + """DrainTimeoutManager remaining time decreases across re-entries.""" + dtm = DrainTimeoutManager.with_timeout(10.0) + with dtm as first: + assert 9.5 < first <= 10.0 + time.sleep(0.1) + with dtm as second: + assert second < first + + +def test_drain_timeout_expired_returns_zero(): + """DrainTimeoutManager with an already-expired timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(0) + with dtm as remaining: + assert remaining == 0 diff --git a/tests/units/reflex_base/event/test_context.py b/tests/units/reflex_base/event/test_context.py new file mode 100644 index 00000000000..484ec696e21 --- /dev/null +++ b/tests/units/reflex_base/event/test_context.py @@ -0,0 +1,84 @@ +"""Tests for EventContext.""" + +from unittest import mock + +from reflex_base.event.context import EventContext + + +def 
test_fork_creates_child(mock_root_event_context: EventContext): + """fork() creates a child context with a new txid and shared impls. + + Args: + mock_root_event_context: The root event context fixture. + """ + child = mock_root_event_context.fork(token="child-tok") + assert child.token == "child-tok" + assert child.parent_txid == mock_root_event_context.txid + assert child.txid != mock_root_event_context.txid + assert child.state_manager is mock_root_event_context.state_manager + assert child.enqueue_impl is mock_root_event_context.enqueue_impl + + +def test_fork_inherits_token(mock_root_event_context: EventContext): + """fork() without token= inherits the parent's token. + + Args: + mock_root_event_context: The root event context fixture. + """ + child = mock_root_event_context.fork() + assert child.token == mock_root_event_context.token + + +async def test_emit_delta(mock_root_event_context: EventContext, emitted_deltas: list): + """emit_delta records the delta via emit_delta_impl. + + Args: + mock_root_event_context: The root event context fixture. + emitted_deltas: List to capture emitted deltas. + """ + ctx = mock_root_event_context.fork(token="tok") + delta = {"state": {"x": 1}} + await ctx.emit_delta(delta) + assert emitted_deltas == [("tok", delta)] + + +async def test_emit_event(mock_root_event_context: EventContext, emitted_events: list): + """emit_event records the event via emit_event_impl. + + Args: + mock_root_event_context: The root event context fixture. + emitted_events: List to capture emitted events. 
+ """ + from reflex.event import Event + + ctx = mock_root_event_context.fork(token="tok") + ev = Event(name="test", payload={}) + await ctx.emit_event(ev) + assert len(emitted_events) == 1 + assert emitted_events[0][0] == "tok" + + +async def test_emit_delta_noop_when_no_impl(): + """emit_delta is a no-op when emit_delta_impl is None.""" + from reflex.istate.manager.memory import StateManagerMemory + + ctx = EventContext( + token="t", + state_manager=StateManagerMemory(), + enqueue_impl=mock.AsyncMock(), + emit_delta_impl=None, + ) + await ctx.emit_delta({"s": {"k": "v"}}) + + +async def test_emit_event_noop_when_no_impl(): + """emit_event is a no-op when emit_event_impl is None.""" + from reflex.istate.manager.memory import StateManagerMemory + + ctx = EventContext( + token="t", + state_manager=StateManagerMemory(), + enqueue_impl=mock.AsyncMock(), + emit_event_impl=None, + ) + await ctx.emit_event() diff --git a/tests/units/reflex_base/test_registry.py b/tests/units/reflex_base/test_registry.py new file mode 100644 index 00000000000..474acf874c8 --- /dev/null +++ b/tests/units/reflex_base/test_registry.py @@ -0,0 +1,133 @@ +"""Tests for RegistrationContext.""" + +import pytest +from reflex_base.registry import RegisteredEventHandler, RegistrationContext +from reflex_base.utils.exceptions import StateValueError + + +def test_ensure_context_creates_if_missing(): + """ensure_context() returns existing context or creates a new one.""" + try: + existing = RegistrationContext._context_var.get() + assert RegistrationContext.ensure_context() is existing + except LookupError: + ctx = RegistrationContext.ensure_context() + assert isinstance(ctx, RegistrationContext) + assert RegistrationContext.get() is ctx + + +def test_clean_context_is_empty(clean_registration_context: RegistrationContext): + """A clean context starts with no handlers or states. + + Args: + clean_registration_context: A fresh, empty registration context. 
+ """ + assert clean_registration_context.event_handlers == {} + assert clean_registration_context.base_states == {} + assert clean_registration_context.base_state_substates == {} + + +def test_register_event_handler(clean_registration_context: RegistrationContext): + """register_event_handler stores the handler keyed by its full name. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.event import EventHandler + + async def my_fn(): + pass + + handler = EventHandler(fn=my_fn) + RegistrationContext.register_event_handler(handler) + assert len(clean_registration_context.event_handlers) == 1 + registered = next(iter(clean_registration_context.event_handlers.values())) + assert isinstance(registered, RegisteredEventHandler) + assert registered.handler is handler + + +def test_register_base_state(clean_registration_context: RegistrationContext): + """BaseState metaclass auto-registers during class definition into the active context. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class AutoRegistered(BaseState): + x: int = 0 + + assert AutoRegistered.get_full_name() in clean_registration_context.base_states + + +def test_duplicate_substate_raises(clean_registration_context: RegistrationContext): + """Registering the same substate twice raises StateValueError. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class DupParent(BaseState): + pass + + class DupChild(DupParent): + pass + + with pytest.raises(StateValueError, match="already registered"): + clean_registration_context._register_base_state(DupChild) + + +def test_get_substates(clean_registration_context: RegistrationContext): + """get_substates returns registered children of a parent. + + Args: + clean_registration_context: A fresh, empty registration context. 
+ """ + from reflex.state import BaseState + + class GetSubRoot(BaseState): + pass + + class GetSub1(GetSubRoot): + pass + + class GetSub2(GetSubRoot): + pass + + substates = clean_registration_context.get_substates(GetSubRoot) + assert GetSub1 in substates + assert GetSub2 in substates + + +def test_get_substates_by_name(clean_registration_context: RegistrationContext): + """get_substates also works when passed a string full name. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class NamedState(BaseState): + pass + + result = clean_registration_context.get_substates(NamedState.get_full_name()) + assert isinstance(result, set) + + +def test_forked_context_is_independent( + forked_registration_context: RegistrationContext, +): + """Changes to a forked context do not affect the original. + + Args: + forked_registration_context: A deep copy of the current registration context. + """ + from reflex.event import EventHandler + + async def _tmp(): + pass + + handler = EventHandler(fn=_tmp) + RegistrationContext.register_event_handler(handler) + assert len(forked_registration_context.event_handlers) > 0 diff --git a/tests/units/test_app.py b/tests/units/test_app.py index a7ca0524232..f331dd6c8b9 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import functools import io import json @@ -10,15 +11,17 @@ from contextlib import nullcontext as does_not_raise from importlib.util import find_spec from pathlib import Path -from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ClassVar -from unittest.mock import AsyncMock +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, Mock import pytest from pytest_mock import MockerFixture from reflex_base.components.component import Component from reflex_base.constants.state import FIELD_MARKER from reflex_base.event 
import Event +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor +from reflex_base.registry import RegistrationContext from reflex_base.style import Style from reflex_base.utils import console, exceptions, format from reflex_base.vars.base import computed_var @@ -26,32 +29,27 @@ from reflex_components_core.base.fragment import Fragment from reflex_components_core.core.cond import Cond from reflex_components_radix.themes.typography.text import Text +from sqlalchemy.engine.base import Engine from starlette.applications import Starlette from starlette.datastructures import FormData, Headers, UploadFile from starlette.responses import StreamingResponse +from starlette_admin.auth import AuthProvider import reflex as rx from reflex import AdminDash, constants -from reflex.app import ( - App, - ComponentCallable, - default_overlay_component, - process, - upload, -) +from reflex.app import App, ComponentCallable, default_overlay_component, upload from reflex.environment import environment from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis -from reflex.middleware import HydrateMiddleware +from reflex.istate.manager.token import BaseStateToken from reflex.model import Model from reflex.state import ( BaseState, OnLoadInternalState, RouterData, State, - StateUpdate, - _substate_key, + reload_state_module, ) from .conftest import chdir @@ -59,7 +57,6 @@ from .states.upload import ( ChildFileUploadState, ChunkUploadState, - FileStateBase1, FileUploadState, GrandChildFileUploadState, ) @@ -213,12 +210,16 @@ def test_default_app(app: App): Args: app: The app to test. 
""" - assert app._middlewares == [HydrateMiddleware()] + assert app._middlewares == [] assert app.style == Style() assert app.admin_dash is None -def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): +def test_multiple_states_error( + monkeypatch: pytest.MonkeyPatch, + test_state: BaseState, + redundant_test_state: BaseState, +): """Test that an error is thrown when multiple classes subclass rx.BaseState. Args: @@ -231,7 +232,9 @@ def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): App() -def test_add_page_default_route(app: App, index_page, about_page): +def test_add_page_default_route( + app: App, index_page: ComponentCallable, about_page: ComponentCallable +): """Test adding a page to an app. Args: @@ -249,7 +252,7 @@ def test_add_page_default_route(app: App, index_page, about_page): assert app._pages.keys() == {"index", "about"} -def test_add_page_set_route(app: App, index_page): +def test_add_page_set_route(app: App, index_page: ComponentCallable): """Test adding a page to an app. Args: @@ -263,7 +266,7 @@ def test_add_page_set_route(app: App, index_page): assert app._pages.keys() == {"test"} -def test_add_page_set_route_dynamic(index_page): +def test_add_page_set_route_dynamic(index_page: ComponentCallable): """Test adding a page with dynamic route variable to an app. Args: @@ -283,7 +286,7 @@ def test_add_page_set_route_dynamic(index_page): assert constants.ROUTER in app._state()._var_dependencies -def test_add_page_set_route_nested(app: App, index_page): +def test_add_page_set_route_nested(app: App, index_page: ComponentCallable): """Test adding a page to an app. Args: @@ -296,7 +299,7 @@ def test_add_page_set_route_nested(app: App, index_page): assert app._unevaluated_pages.keys() == {route} -def test_add_page_invalid_api_route(app: App, index_page): +def test_add_page_invalid_api_route(app: App, index_page: ComponentCallable): """Test adding a page with an invalid route to an app. 
Args: @@ -371,7 +374,7 @@ def test_add_duplicate_page_route_error(app: App, first_page, second_page, route or not find_spec("pydantic"), reason="starlette_admin not installed or sqlmodel not installed or pydantic not installed", ) -def test_initialize_with_admin_dashboard(test_model): +def test_initialize_with_admin_dashboard(test_model: Model): """Test setting the admin dashboard of an app. Args: @@ -390,9 +393,9 @@ def test_initialize_with_admin_dashboard(test_model): reason="starlette_admin not installed or sqlmodel not installed or pydantic not installed", ) def test_initialize_with_custom_admin_dashboard( - test_get_engine, - test_custom_auth_admin, - test_model_auth, + test_get_engine: Engine, + test_custom_auth_admin: type[AuthProvider], + test_model_auth: Model, ): """Test setting the custom admin dashboard of an app. @@ -452,7 +455,9 @@ async def test_initialize_with_state(test_state: type[ATestState], token: str): assert app._state == test_state # Get a state for a given token. - state = await app.state_manager.get_state(_substate_key(token, test_state)) + state = await app.state_manager.get_state( + BaseStateToken(ident=token, cls=test_state) + ) assert isinstance(state, test_state) assert state.var == 0 @@ -469,8 +474,8 @@ async def test_set_and_get_state(test_state: type[ATestState]): app = App(_state=test_state) # Create two tokens. - token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" - token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" + token1 = BaseStateToken(ident=str(uuid.uuid4()), cls=test_state) + token2 = BaseStateToken(ident=str(uuid.uuid4()), cls=test_state) # Get the default state for each token. 
state1 = await app.state_manager.get_state(token1) @@ -498,25 +503,37 @@ async def test_set_and_get_state(test_state: type[ATestState]): @pytest.mark.asyncio -async def test_dynamic_var_event(test_state: type[ATestState], token: str): +async def test_dynamic_var_event( + test_state: type[ATestState], + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], + token: str, + clean_registration_context: RegistrationContext, +): """Test that the default handler of a dynamic generated var works as expected. Args: test_state: State Fixture. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. token: a Token. + clean_registration_context: The registration context fixture, which is cleared before each test. """ + clean_registration_context.register_base_state(test_state) state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) - async for result in state._process( - Event( - token=token, - name=f"{test_state.get_name()}.set_int_val", - router_data={"pathname": "/", "query": {}}, - payload={"value": 50}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + name=f"{test_state.get_name()}.set_int_val", + payload={"value": 50}, + ), ) - ): - assert result.delta == {test_state.get_name(): {"int_val" + FIELD_MARKER: 50}} + assert emitted_deltas == [ + (token, {test_state.get_name(): {"int_val" + FIELD_MARKER: 50}}) + ] @pytest.fixture @@ -691,6 +708,8 @@ async def test_list_mutation_detection__plain_list( event_tuples: list[tuple[str, list[str]]], list_mutation_state: State, token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], ): """Test list mutation detection when reassignment is not explicitly included in the logic. 
@@ -699,19 +718,22 @@ async def test_list_mutation_detection__plain_list( event_tuples: From parametrization. list_mutation_state: A state with list mutation features. token: a Token. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. """ for event_name, expected_delta in event_tuples: - async for result in list_mutation_state._process( - Event( - token=token, - name=f"{list_mutation_state.get_name()}.{event_name}", - router_data={"pathname": "/", "query": {}}, - payload={}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + name=f"{list_mutation_state.get_name()}.{event_name}", + payload={}, + ), ) - ): - # prefix keys in expected_delta with the state name - expected_delta = {list_mutation_state.get_name(): expected_delta} - assert result.delta == expected_delta + # prefix keys in expected_delta with the state name + expected_delta = {list_mutation_state.get_name(): expected_delta} + assert emitted_deltas == [(token, expected_delta)] + emitted_deltas.clear() # Clear emitted deltas for the next iteration @pytest.fixture @@ -883,6 +905,8 @@ async def test_dict_mutation_detection__plain_list( event_tuples: list[tuple[str, list[str]]], dict_mutation_state: State, token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], ): """Test dict mutation detection when reassignment is not explicitly included in the logic. @@ -891,20 +915,22 @@ async def test_dict_mutation_detection__plain_list( event_tuples: From parametrization. dict_mutation_state: A state with dict mutation features. token: a Token. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. 
""" for event_name, expected_delta in event_tuples: - async for result in dict_mutation_state._process( - Event( - token=token, - name=f"{dict_mutation_state.get_name()}.{event_name}", - router_data={"pathname": "/", "query": {}}, - payload={}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + name=f"{dict_mutation_state.get_name()}.{event_name}", + payload={}, + ), ) - ): - # prefix keys in expected_delta with the state name - expected_delta = {dict_mutation_state.get_name(): expected_delta} - - assert result.delta == expected_delta + # prefix keys in expected_delta with the state name + expected_delta = {dict_mutation_state.get_name(): expected_delta} + assert emitted_deltas == [(token, expected_delta)] + emitted_deltas.clear() # Clear emitted deltas for the next iteration @pytest.mark.asyncio @@ -937,7 +963,16 @@ async def test_dict_mutation_detection__plain_list( ), ], ) -async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFixture): +async def test_upload_file( + tmp_path: Path, + state, + delta, + token: str, + mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, + clean_registration_context: RegistrationContext, +): """Test that file upload works correctly. Args: @@ -946,16 +981,18 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFix delta: Expected delta after processing all files. token: a Token. mocker: pytest mocker object. + attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. + mock_root_event_context: The mocked root event context, for accessing state_manager. + clean_registration_context: Fixture to ensure clean registration context for each test, preventing cross-test contamination of state subclasses. 
""" - mocker.patch( - "reflex.state.State.class_subclasses", - {state if state is FileUploadState else FileStateBase1}, + clean_registration_context.register_base_state(state) + app = Mock( + event_processor=attached_mock_base_state_event_processor, ) - # The App state must be the "root" of the state tree - app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - async with app.modify_state(_substate_key(token, state)) as root_state: - root_state.get_substate(state.get_full_name().split("."))._tmp_path = tmp_path + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=state) + ) as root_state: + (await root_state.get_state(state))._tmp_path = tmp_path data = b"This is binary data" request_mock = unittest.mock.Mock() @@ -985,22 +1022,20 @@ async def form(): # noqa: RUF029 updates = [] async for state_update in streaming_response.body_iterator: updates.append(json.loads(str(state_update))) - # 2 intermediate yields + 1 final - assert len(updates) == 3 - assert all(not u["final"] for u in updates[:-1]) - assert updates[-1]["final"] + # 2 intermediate yields + assert len(updates) == 2 # The last intermediate update should contain the full cumulative delta. assert updates[1]["delta"] == delta - await app.state_manager.close() - @pytest.mark.asyncio async def test_upload_file_keeps_form_open_until_stream_completes( - tmp_path, + tmp_path: Path, token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that upload files are not eagerly copied into memory. @@ -1012,19 +1047,16 @@ async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path: Temporary path. token: A token. mocker: pytest mocker object. + attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. 
+ mock_root_event_context: The mocked root event context, for accessing state_manager. """ - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) - app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) # Set _tmp_path via modify_state instead of setting class attribute directly. - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: - root_state.get_substate( - FileUploadState.get_full_name().split(".") - )._tmp_path = tmp_path + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: + (await root_state.get_state(FileUploadState))._tmp_path = tmp_path request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1083,25 +1115,21 @@ async def send(message): # noqa: RUF029 assert (tmp_path / "image1.jpg").read_bytes() == data1 assert (tmp_path / "image2.jpg").read_bytes() == data2 - await app.state_manager.close() - @pytest.mark.asyncio async def test_upload_empty_buffered_request_dispatches_alias_handler( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch buffered alias handlers.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) - app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: - substate = root_state.get_substate(FileUploadState.get_full_name().split(".")) - substate.img_list = [] + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: + (await 
root_state.get_state(FileUploadState)).img_list = [] request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1122,34 +1150,29 @@ async def form(): # noqa: RUF029 async for state_update in streaming_response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates[-1]["final"] - + assert len(updates) == 1 + assert updates[0]["delta"] == { + FileUploadState.get_full_name(): {"img_list" + FIELD_MARKER: ["count:0"]} + } if environment.REFLEX_OPLOCK_ENABLED.get(): - await app.state_manager.close() + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, FileUploadState)) - substate = ( - state - if isinstance(state, FileUploadState) - else state.get_substate(FileUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=FileUploadState) ) + substate = await state.get_state(FileUploadState) assert isinstance(substate, FileUploadState) assert substate.img_list == ["count:0"] - await app.state_manager.close() - @pytest.mark.asyncio -async def test_upload_file_closes_form_on_event_creation_cancellation( +async def test_upload_file_closes_form_on_form_error( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, ): """Test that cancellation before form parsing leaves form data untouched.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) - app = App() + app = Mock(event_processor=attached_mock_base_state_event_processor) request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1163,15 +1186,11 @@ async def test_upload_file_closes_form_on_event_creation_cancellation( form_close = AsyncMock(side_effect=original_close) form_data.close = form_close - async def form(): # noqa: RUF029 - return form_data - - async def cancelled_get_state(*_args, **_kwargs): + async def cancelled_form(): await asyncio.sleep(0) raise 
asyncio.CancelledError - request_mock.form = form - mocker.patch.object(app.state_manager, "get_state", side_effect=cancelled_get_state) + request_mock.form = cancelled_form upload_fn = upload(app) with pytest.raises(asyncio.CancelledError): @@ -1180,27 +1199,62 @@ async def cancelled_get_state(*_args, **_kwargs): assert form_close.await_count == 0 assert not file1.file.closed - await app.state_manager.close() + +@pytest.mark.asyncio +async def test_upload_file_closes_form_on_event_creation_cancellation( + token: str, + mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, +): + """Test that cancellation during event creation closes form data.""" + app = Mock(event_processor=attached_mock_base_state_event_processor) + + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", + } + + bio = io.BytesIO(b"data") + file1 = UploadFile(filename="image1.jpg", file=bio) + form_data = FormData([("files", file1)]) + original_close = form_data.close + form_close = AsyncMock(side_effect=original_close) + form_data.close = form_close + + async def form(): # noqa: RUF029 + return form_data + + request_mock.form = form + + # Patch getlist on the form_data to raise CancelledError during event + # creation (after form is parsed, before streaming begins). + form_data.getlist = Mock(side_effect=asyncio.CancelledError) + + upload_fn = upload(app) + with pytest.raises(asyncio.CancelledError): + await upload_fn(request_mock) + + # Form was parsed, so it should be closed on failure. 
+ assert form_close.await_count == 1 + assert bio.closed @pytest.mark.asyncio async def test_upload_file_closes_form_if_response_cancelled_before_stream_starts( - tmp_path, + tmp_path: Path, token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that response cancellation before iteration still closes form data.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) - app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: - root_state.get_substate( - FileUploadState.get_full_name().split(".") - )._tmp_path = tmp_path + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: + (await root_state.get_state(FileUploadState))._tmp_path = tmp_path request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1245,15 +1299,17 @@ async def send(_message): assert form_close.await_count == 1 assert bio.closed - await app.state_manager.close() - @pytest.mark.asyncio @pytest.mark.parametrize( "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -async def test_upload_file_without_annotation(state, tmp_path, token): +async def test_upload_file_without_annotation( + state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, + tmp_path: Path, + token: str, +): """Test that an error is thrown when there's no param annotated with rx.UploadFile or list[UploadFile]. 
Args: @@ -1292,7 +1348,11 @@ async def form(): # noqa: RUF029 "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -async def test_upload_file_background(state, tmp_path, token): +async def test_upload_file_background( + state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, + tmp_path: Path, + token: str, +): """Test that an error is thrown handler is a background task. Args: @@ -1386,41 +1446,21 @@ async def stream(): return request_mock -async def _drain_background_tasks(app: App): - """Wait for all background tasks associated with an app. - - Returns: - The gathered background task results. - """ - tasks = tuple(app._background_tasks) - results = await asyncio.gather(*tasks, return_exceptions=True) if tasks else [] - if environment.REFLEX_OPLOCK_ENABLED.get(): - # Redis oplocks can keep completed background-task writes in the local - # lease cache until the manager is closed. - await app.state_manager.close() - return results - - @pytest.mark.asyncio async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( tmp_path, token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that the standard upload endpoint dispatches chunk handlers.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + 
substate = await root_state.get_state(ChunkUploadState) substate._tmp_path = tmp_path substate.chunk_records = [] substate.completed_files = [] @@ -1449,17 +1489,16 @@ async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( updates = [] async for state_update in response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates == [{"delta": {}, "events": [], "final": True}] + assert updates == [{}] - task_results = await _drain_background_tasks(app) - assert all(result is None for result in task_results) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) parsed_chunk_records = [ (filename, int(offset), int(size), content_type) @@ -1496,31 +1535,22 @@ async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( assert substate.completed_files == ["alpha.txt", "beta.txt"] assert (tmp_path / "alpha.txt").read_bytes() == b"abcde" assert (tmp_path / "beta.txt").read_bytes() == b"12345" - assert app.event_namespace.emit_update.await_count >= 1 # pyright: ignore [reportOptionalMemberAccess] - assert not app._background_tasks - - await app.state_manager.close() @pytest.mark.asyncio async def test_upload_empty_chunk_request_dispatches_alias_handler( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch chunk alias handlers.""" - mocker.patch( - 
"reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + substate = await root_state.get_state(ChunkUploadState) substate.chunk_records = [] substate.completed_files = [] @@ -1541,44 +1571,37 @@ async def test_upload_empty_chunk_request_dispatches_alias_handler( updates = [] async for state_update in response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates == [{"delta": {}, "events": [], "final": True}] + assert updates == [{}] - task_results = await _drain_background_tasks(app) - assert all(result is None for result in task_results) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) assert substate.chunk_records == [] assert substate.completed_files == ["chunks:0"] - assert not app._background_tasks - - await app.state_manager.close() @pytest.mark.asyncio async def 
test_upload_chunk_invalid_offset_returns_400( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that malformed chunk metadata fails the standard upload request.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) + # The background task is expected to fail with a parse error for malformed input. + attached_mock_base_state_event_processor.backend_exception_handler = None - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + substate = await root_state.get_state(ChunkUploadState) substate.chunk_records = [] substate.completed_files = [] @@ -1597,23 +1620,20 @@ async def test_upload_chunk_invalid_offset_returns_400( "detail": "Missing boundary in multipart." 
} - await _drain_background_tasks(app) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) assert substate.chunk_records == [] assert substate.completed_files == [] - assert not app._background_tasks - - await app.state_manager.close() -class DynamicState(BaseState): +class DynamicState(State): """State class for testing dynamic route var. This is defined at module level because event handlers cannot be addressed @@ -1630,7 +1650,6 @@ class DynamicState(BaseState): is_hydrated: bool = False loaded: int = 0 counter: int = 0 - _app_ref: ClassVar[Any] = None @rx.event def on_load(self): @@ -1651,8 +1670,6 @@ def comp_dynamic(self) -> str: """ return self.dynamic # pyright: ignore[reportAttributeAccessIssue] - on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess] - def test_dynamic_arg_shadow( index_page: ComponentCallable, @@ -1668,7 +1685,6 @@ def test_dynamic_arg_shadow( app_module_mock: Mocked app module. mocker: pytest mocker object. 
""" - DynamicState._app_ref = None arg_name = "counter" route = f"/test/[{arg_name}]" app = app_module_mock.app = App(_state=DynamicState) @@ -1699,12 +1715,31 @@ def test_multiple_dynamic_args( app.add_page(index_page, route=route2) +@pytest.fixture +def cleanup_dynamic_arg(): + """Fixture to reset DynamicState class vars after each test.""" + yield + with contextlib.suppress(AttributeError): + del State.dynamic # pyright: ignore[reportAttributeAccessIssue] + + State.computed_vars.pop("dynamic", None) + State.vars.pop("dynamic", None) + State._var_dependencies = {} + State._potentially_dirty_states = set() + State._always_dirty_computed_vars = set() + reload_state_module(__name__) + + +@pytest.mark.usefixtures("cleanup_dynamic_arg") @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( index_page: ComponentCallable, token: str, app_module_mock: unittest.mock.Mock, - mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], + emitted_events: list[tuple[str, tuple[Event, ...]]], ): """Create app with dynamic route var, and simulate navigation. @@ -1715,12 +1750,16 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: The index page. token: a Token. app_module_mock: Mocked app module. - mocker: pytest mocker object. + mock_root_event_context: Mocked root event context. + mock_base_state_event_processor: Mocked BaseStateEventProcessor. + emitted_deltas: List to store emitted deltas. + emitted_events: List to store emitted events. 
""" - DynamicState._app_ref = None + OnLoadInternalState._app_ref = None arg_name = "dynamic" route = f"test/[{arg_name}]" - app = app_module_mock.app = App(_state=DynamicState) + app = app_module_mock.app = App() + app._state_manager = mock_root_event_context.state_manager assert app._state is not None assert arg_name not in app._state.vars app.add_page(index_page, route=route, on_load=DynamicState.on_load) @@ -1732,24 +1771,18 @@ async def test_dynamic_route_var_route_change_completed_on_load( } assert constants.ROUTER in app._state()._var_dependencies - substate_token = _substate_key(token, DynamicState) - sid = "mock_sid" - client_ip = "127.0.0.1" - async with app.state_manager.modify_state(substate_token) as state: - state.router_data = {"simulate": "hydrated"} - assert state.dynamic == "" # pyright: ignore[reportAttributeAccessIssue] + substate_token = BaseStateToken(ident=token, cls=DynamicState) exp_vals = ["foo", "foobar", "baz"] def _event(name, val, **kwargs): return Event( - token=kwargs.pop("token", token), name=name, router_data=kwargs.pop( "router_data", { "pathname": "/" + route, "query": {arg_name: val}, - "asPath": "/test/something", + "asPath": f"/test/{val}", }, ), payload=kwargs.pop("payload", {}), @@ -1766,57 +1799,53 @@ def _dynamic_state_event(name, val, **kwargs): prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): on_load_internal = _event( - name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", + name=f"{OnLoadInternalState.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", val=exp_val, ) - exp_router_data = { - "headers": {}, - "ip": client_ip, - "sid": sid, - "token": token, - **on_load_internal.router_data, - } - exp_router = RouterData.from_router_data(exp_router_data) - process_coro = process( - app, - event=on_load_internal, - sid=sid, - headers={}, - client_ip=client_ip, - ) - update = await process_coro.__anext__() - # route change (on_load_internal) 
triggers: [call on_load events, call set_is_hydrated(True)] - assert update == StateUpdate( - delta={ - state.get_name(): { - arg_name + FIELD_MARKER: exp_val, - f"comp_{arg_name}" + FIELD_MARKER: exp_val, - constants.CompileVars.IS_HYDRATED + FIELD_MARKER: False, - "router" + FIELD_MARKER: exp_router, - } - }, - events=[ - _dynamic_state_event( - name="on_load", - val=exp_val, - ), - _event( - name=f"{State.get_name()}.set_is_hydrated", - payload={"value": True}, - val=exp_val, - router_data={}, - ), - ], - ) + exp_router = RouterData.from_router_data(on_load_internal.router_data) + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + on_load_internal, + ) + await processor.join() + assert emitted_deltas == [ + ( + token, + { + State.get_full_name(): { + arg_name + FIELD_MARKER: exp_val, + constants.CompileVars.IS_HYDRATED + FIELD_MARKER: False, + "router" + FIELD_MARKER: exp_router, + }, + DynamicState.get_full_name(): { + f"comp_{arg_name}" + FIELD_MARKER: exp_val, + }, + }, + ), + ( + token, + { + DynamicState.get_full_name(): { + "loaded" + FIELD_MARKER: exp_index + 1, + }, + }, + ), + ( + token, + { + State.get_full_name(): { + "is_hydrated" + FIELD_MARKER: True, + }, + }, + ), + ] + assert emitted_events == [] if isinstance(app.state_manager, StateManagerRedis): # When redis is used, the state is not updated until the processing is complete state = await app.state_manager.get_state(substate_token) assert state.dynamic == prev_exp_val # pyright: ignore[reportAttributeAccessIssue] - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() @@ -1824,122 +1853,84 @@ def _dynamic_state_event(name, val, **kwargs): state = await app.state_manager.get_state(substate_token) assert state.dynamic == exp_val # pyright: ignore[reportAttributeAccessIssue] - process_coro = process( - app, - 
event=_dynamic_state_event(name="on_load", val=exp_val), - sid=sid, - headers={}, - client_ip=client_ip, - ) - on_load_update = await process_coro.__anext__() - assert on_load_update == StateUpdate( - delta={ - state.get_name(): { - "loaded" + FIELD_MARKER: exp_index + 1, - }, - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - process_coro = process( - app, - event=_dynamic_state_event( - name="set_is_hydrated", payload={"value": True}, val=exp_val - ), - sid=sid, - headers={}, - client_ip=client_ip, - ) - on_set_is_hydrated_update = await process_coro.__anext__() - assert on_set_is_hydrated_update == StateUpdate( - delta={ - state.get_name(): { - "is_hydrated" + FIELD_MARKER: True, - }, - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - # a simple state update event should NOT trigger on_load or route var side effects - process_coro = process( - app, - event=_dynamic_state_event(name="on_counter", val=exp_val), - sid=sid, - headers={}, - client_ip=client_ip, - ) - update = await process_coro.__anext__() - assert update == StateUpdate( - delta={ - state.get_name(): { - "counter" + FIELD_MARKER: exp_index + 1, - } - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - + emitted_deltas.clear() + emitted_events.clear() + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + _dynamic_state_event(name="on_counter", val=exp_val), + ) + assert emitted_deltas == [ + ( + token, + { + DynamicState.get_full_name(): { + "counter" + FIELD_MARKER: exp_index + 1, + } + }, + ) + ] + assert emitted_events == [] + emitted_deltas.clear() + emitted_events.clear() prev_exp_val = exp_val if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() state = await app.state_manager.get_state(substate_token) - 
assert isinstance(state, DynamicState) - assert state.loaded == len(exp_vals) - assert state.counter == len(exp_vals) + assert isinstance(state, State) + dynamic_state = await state.get_state(DynamicState) + assert isinstance(dynamic_state, DynamicState) + assert dynamic_state.loaded == len(exp_vals) + assert dynamic_state.counter == len(exp_vals) await app.state_manager.close() @pytest.mark.asyncio -async def test_process_events(mocker: MockerFixture, token: str): +async def test_process_events( + token: str, + app_module_mock: unittest.mock.Mock, + mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], +): """Test that an event is processed properly and that it is postprocessed n+1 times. Also check that the processing flag of the last stateupdate is set to False. Args: - mocker: mocker object. token: a Token. + app_module_mock: The mock for the app module, used to patch the app instance. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + mock_root_event_context: The mock for the root event context, used to patch the app + state manager. + emitted_deltas: List to store emitted deltas. 
""" - router_data = { - "pathname": "/", - "query": {}, - "token": token, - "sid": "mock_sid", - "headers": {}, - "ip": "127.0.0.1", - } - app = App(_state=GenState) - - mocker.patch.object(app, "_postprocess", AsyncMock()) event = Event( - token=token, name=f"{GenState.get_name()}.go", payload={"c": 5}, - router_data=router_data, + router_data={}, ) - async with app.state_manager.modify_state(event.substate_token) as state: - state.router_data = {"simulate": "hydrated"} - async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): - pass + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + event, + ) if environment.REFLEX_OPLOCK_ENABLED.get(): - await app.state_manager.close() + await mock_root_event_context.state_manager.close() - gen_state = await app.state_manager.get_state(event.substate_token) + gen_state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=GenState), + ) assert isinstance(gen_state, GenState) assert gen_state.value == 5 - assert app._postprocess.call_count == 6 # pyright: ignore [reportAttributeAccessIssue] + assert len(emitted_deltas) == 5 - await app.state_manager.close() + await mock_root_event_context.state_manager.close() @pytest.mark.parametrize( @@ -2000,7 +1991,7 @@ def test_overlay_component( @pytest.fixture -def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: +def compilable_app(tmp_path: Path) -> Generator[tuple[App, Path], None, None]: """Fixture for an app that can be compiled. Args: @@ -2038,7 +2029,9 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: [True, False], ) def test_app_wrap_compile_theme( - react_strict_mode: bool, compilable_app: tuple[App, Path], mocker + react_strict_mode: bool, + compilable_app: tuple[App, Path], + mocker: MockerFixture, ): """Test that the radix theme component wraps the app. 
@@ -2090,7 +2083,9 @@ def test_app_wrap_compile_theme( [True, False], ) def test_app_wrap_priority( - react_strict_mode: bool, compilable_app: tuple[App, Path], mocker + react_strict_mode: bool, + compilable_app: tuple[App, Path], + mocker: MockerFixture, ): """Test that the app wrap components are wrapped in the correct order. @@ -2497,7 +2492,7 @@ class Sub(Base): app._event_namespace = AsyncMock() async with app.modify_state( - token=_substate_key(token, Sub.get_name()) + token=BaseStateToken(ident=token, cls=Sub) ) as root_state: sub = root_state.substates[Sub.get_name()] if substate: diff --git a/tests/units/test_event.py b/tests/units/test_event.py index be2482ccdda..31c750b57dd 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -43,8 +43,7 @@ def make_timeout_logger() -> EventChainVar: def test_create_event(): """Test creating an event.""" - event = Event(token="token", name="state.do_thing", payload={"arg": "value"}) - assert event.token == "token" + event = Event(name="state.do_thing", payload={"arg": "value"}) assert event.name == "state.do_thing" assert event.payload == {"arg": "value"} @@ -57,7 +56,7 @@ def test_fn(): test_fn.__qualname__ = "test_fn" - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" @@ -112,7 +111,7 @@ def fn_with_args(_, arg1, arg2): def test_call_event_handler_partial(): """Calling an EventHandler with incomplete args returns an EventSpec that can be extended.""" - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" @@ -120,7 +119,7 @@ def fn_with_args(_, arg1, arg2): def spec(a2: Var[str]) -> list[Var[str]]: return [a2] - handler = EventHandler(fn=fn_with_args, state_full_name="BigState") + handler = EventHandler(fn=fn_with_args) event_spec = handler(make_var("first")) event_spec2 = call_event_handler(event_spec, spec) @@ -129,8 +128,7 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert 
event_spec.args[0][0].equals(Var(_js_expr="arg1")) assert event_spec.args[0][1].equals(Var(_js_expr="first")) assert ( - format.format_event(event_spec) - == 'ReflexEvent("BigState.fn_with_args", {arg1:first})' + format.format_event(event_spec) == 'ReflexEvent("fn_with_args", {arg1:first})' ) assert event_spec2 is not event_spec @@ -142,7 +140,7 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert ( format.format_event(event_spec2) - == 'ReflexEvent("BigState.fn_with_args", {arg1:first,arg2:_a2})' + == 'ReflexEvent("fn_with_args", {arg1:first,arg2:_a2})' ) @@ -162,16 +160,15 @@ def test_fix_events(arg1, arg2): arg2: The second arg passed to the handler. """ - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" handler = EventHandler(fn=fn_with_args) event_spec = handler(arg1, arg2) - event = fix_events([event_spec], token="foo")[0] + event = fix_events([event_spec])[0] assert event.name == fn_with_args.__qualname__ - assert event.token == "foo" assert event.payload == {"arg1": arg1, "arg2": arg2} diff --git a/tests/units/test_model.py b/tests/units/test_model.py index 1ca665207d4..82b61a57c22 100644 --- a/tests/units/test_model.py +++ b/tests/units/test_model.py @@ -4,6 +4,7 @@ import pytest from reflex_base.constants.state import FIELD_MARKER +from reflex_base.event import Event import reflex.constants import reflex.model @@ -221,25 +222,37 @@ def rx_model(self, m: ReflexModel): # noqa: D102 @pytest.mark.asyncio -@pytest.mark.usefixtures("mock_app_simple") @pytest.mark.parametrize( ("handler", "payload"), [ (UpcastStateWithSqlAlchemy.rx_model, {"m": {"foo": "bar"}}), ], ) -async def test_upcast_event_handler_arg(handler, payload): +async def test_upcast_event_handler_arg( + handler, payload, mock_base_state_event_processor, emitted_deltas +): """Test that upcast event handler args work correctly. Args: handler: The handler to test. 
payload: The payload to test. + mock_base_state_event_processor: Fixture for processing events with a BaseState. + emitted_deltas: List to store emitted deltas. """ - state = UpcastStateWithSqlAlchemy() - async for update in state._process_event(handler, state, payload): - assert update.delta == { - UpcastStateWithSqlAlchemy.get_full_name(): {"passed" + FIELD_MARKER: True} - } + async with mock_base_state_event_processor as processor: + await processor.enqueue( + "test_token", Event.from_event_type(handler(**payload))[0] + ) + assert emitted_deltas == [ + ( + "test_token", + { + UpcastStateWithSqlAlchemy.get_full_name(): { + "passed" + FIELD_MARKER: True + } + }, + ), + ] def test_no_rebind_mutable_proxy_for_instrumented_functions(): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 3ad9360b385..a61add0a435 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,7 @@ import os import sys import threading -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable, Mapping from textwrap import dedent from typing import Any, ClassVar from unittest.mock import AsyncMock, Mock @@ -22,9 +22,11 @@ from pydantic import BaseModel as Base from pytest_mock import MockerFixture from reflex_base import constants -from reflex_base.constants import CompileVars, RouteVar, SocketEvent +from reflex_base.constants import CompileVars, RouteVar from reflex_base.constants.state import FIELD_MARKER from reflex_base.event import Event, EventHandler +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor from reflex_base.utils import format, types from reflex_base.utils.exceptions import ( InvalidLockWarningThresholdError, @@ -45,6 +47,8 @@ from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis +from 
reflex.istate.manager.token import BaseStateToken +from reflex.istate.proxy import StateProxy from reflex.state import ( BaseState, ImmutableMutableProxy, @@ -53,13 +57,9 @@ OnLoadInternalState, RouterData, State, - StateProxy, - StateUpdate, - _substate_key, ) from reflex.testing import chdir from reflex.utils import prerequisites -from reflex.utils.token_manager import SocketRecord from tests.units.mock_redis import mock_redis from .states import GenState @@ -807,101 +807,117 @@ def test_reset(test_state: TestState, child_state: ChildState): @pytest.mark.asyncio -async def test_process_event_simple(test_state): +async def test_process_event_simple( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test processing an event. Args: - test_state: A state. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ - assert test_state.num1 == 0 - - event = Event(token="t", name="set_num1", payload={"value": 69}) - async for update in test_state._process(event): - # The event should update the value. - assert test_state.num1 == 69 - - # The delta should contain the changes, including computed vars. - assert update.delta == { - TestState.get_full_name(): { - "num1" + FIELD_MARKER: 69, - "sum" + FIELD_MARKER: 72.15, + event = Event( + name=f"{TestState.get_full_name()}.set_num1", + payload={"value": 69}, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + # The delta should contain the changes, including computed vars. 
+ assert emitted_deltas == [ + ( + token, + { + TestState.get_full_name(): { + "num1" + FIELD_MARKER: 69, + "sum" + FIELD_MARKER: 72.15, + }, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, }, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } - assert update.events == [] + ) + ] @pytest.mark.asyncio -async def test_process_event_substate(test_state, child_state, grandchild_state): +async def test_process_event_substate( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test processing an event on a substate. Args: - test_state: A state. - child_state: A child state. - grandchild_state: A grandchild state. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ # Events should bubble down to the substate. - assert child_state.value == "" - assert child_state.count == 23 event = Event( - token="t", - name=f"{ChildState.get_name()}.change_both", + name=f"{ChildState.get_full_name()}.change_both", payload={"value": "hi", "count": 12}, ) - async for update in test_state._process(event): - assert child_state.value == "HI" - assert child_state.count == 24 - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - ChildState.get_full_name(): { - "value" + FIELD_MARKER: "HI", - "count" + FIELD_MARKER: 24, + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert emitted_deltas == [ + ( + token, + { + ChildState.get_full_name(): { + "value" + FIELD_MARKER: "HI", + "count" + FIELD_MARKER: 24, + }, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, }, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } - test_state._clean() + ) + ] + emitted_deltas.clear() # Test with the grandchild state. 
- assert grandchild_state.value2 == "" event = Event( - token="t", name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) - async for update in test_state._process(event): - assert grandchild_state.value2 == "new" - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - GrandchildState.get_full_name(): {"value2" + FIELD_MARKER: "new"}, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert emitted_deltas == [ + ( + token, + { + GrandchildState.get_full_name(): {"value2" + FIELD_MARKER: "new"}, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, + }, + ) + ] @pytest.mark.asyncio -async def test_process_event_generator(): - """Test event handlers that generate multiple updates.""" - gen_state = GenState() # pyright: ignore [reportCallIssue] +async def test_process_event_generator( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): + """Test event handlers that generate multiple updates. + + Args: + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. + """ event = Event( - token="t", - name="go", + name=f"{GenState.get_full_name()}.go", payload={"c": 5}, ) - gen = gen_state._process(event) - - count = 0 - async for update in gen: - count += 1 - if count == 6: - assert update.delta == {} - assert update.final - else: - assert gen_state.value == count - assert update.delta == { - GenState.get_full_name(): {"value" + FIELD_MARKER: count}, - } - assert not update.final - - assert count == 6 + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + # Generator yields 5 deltas (one per increment). 
+ assert len(emitted_deltas) == 5 + for count, (delta_token, delta) in enumerate(emitted_deltas, 1): + assert delta_token == token + assert delta == { + GenState.get_full_name(): {"value" + FIELD_MARKER: count}, + } def test_get_client_token(test_state, router_data): @@ -1644,12 +1660,15 @@ def reset(self): @pytest.mark.asyncio -async def test_state_with_invalid_yield(capsys: pytest.CaptureFixture[str], mock_app): +async def test_state_with_invalid_yield( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, +): """Test that an error is thrown when a state yields an invalid value. Args: - capsys: Pytest fixture for capture standard streams. - mock_app: Mock app fixture. + token: A token. + mock_base_state_event_processor: The event processor. """ class StateWithInvalidYield(BaseState): @@ -1663,60 +1682,29 @@ def invalid_handler(self): """ yield 1 - invalid_state = StateWithInvalidYield() - async for update in invalid_state._process( - rx.event.Event(token="fake_token", name="invalid_handler") - ): - assert not update.delta - assert update.events == rx.event.fix_events( - [ - rx.toast( - "An error occurred.", - level="error", - fallback_to_alert=True, - description="TypeError: Your handler test_state_with_invalid_yield..StateWithInvalidYield.invalid_handler must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references). Returned events of types ..
See logs for details.", - id="backend_error", - position="top-center", - style={"width": "500px"}, - ) - ], - token="", - ) - captured = capsys.readouterr() - assert "must only return/yield: None, Events or other EventHandlers" in captured.err - + captured_exceptions: list[Exception] = [] -@pytest_asyncio.fixture( - loop_scope="function", scope="function", params=["in_process", "disk", "redis"] -) -async def state_manager(request) -> AsyncGenerator[StateManager, None]: - """Instance of state manager parametrized for redis and in-process. + def capture_exception(ex: Exception) -> None: + captured_exceptions.append(ex) - Args: - request: pytest request object. - - Yields: - A state manager instance - """ - state_manager = StateManager.create(state=TestState) - if request.param == "redis": - if not isinstance(state_manager, StateManagerRedis): - state_manager = StateManagerRedis(state=TestState, redis=mock_redis()) - elif request.param == "disk": - # explicitly NOT using redis - state_manager = StateManagerDisk(state=TestState) - assert not state_manager._states_locks - else: - state_manager = StateManagerMemory(state=TestState) - assert not state_manager._states_locks + mock_base_state_event_processor.backend_exception_handler = capture_exception - yield state_manager + event = Event( + name=f"{StateWithInvalidYield.get_full_name()}.invalid_handler", + payload={}, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) - await state_manager.close() + assert len(captured_exceptions) == 1 + assert isinstance(captured_exceptions[0], TypeError) + assert "must only return/yield: None, Events or other EventHandlers" in str( + captured_exceptions[0] + ) @pytest.fixture -def substate_token(state_manager, token) -> str: +def substate_token(state_manager, token) -> BaseStateToken: """A token + substate name for looking up in state manager. 
Args: @@ -1726,12 +1714,12 @@ def substate_token(state_manager, token) -> str: Returns: Token concatenated with the state_manager's state full_name. """ - return _substate_key(token, state_manager.state) + return BaseStateToken(ident=token, cls=TestState) @pytest.mark.asyncio async def test_state_manager_modify_state( - state_manager: StateManager, token: str, substate_token: str + state_manager: StateManager, token: str, substate_token: BaseStateToken ): """Test that the state manager can modify a state exclusively. @@ -1759,10 +1747,11 @@ async def test_state_manager_modify_state( if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - assert not state_manager._states_locks[token].locked() + lock = state_manager._states_locks.get(token) + assert lock is None or not lock.locked() # separate instances should NOT share locks - sm2 = type(state_manager)(state=TestState) + sm2 = type(state_manager)() assert sm2._state_manager_lock is state_manager._state_manager_lock assert not sm2._states_locks if state_manager._states_locks: @@ -1773,7 +1762,7 @@ async def test_state_manager_modify_state( @pytest.mark.asyncio async def test_state_manager_contend( - state_manager: StateManager, token: str, substate_token: str + state_manager: StateManager, token: str, substate_token: BaseStateToken ): """Multiple coroutines attempting to access the same state. 
@@ -1809,8 +1798,77 @@ async def _coro(): if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - assert token in state_manager._states_locks - assert not state_manager._states_locks[token].locked() + lock = state_manager._states_locks.get(token) + assert lock is None or not lock.locked() + + +@pytest.mark.asyncio +async def test_state_manager_legacy_token(state_manager: StateManager, token: str): + """Test that passing a legacy string token to the state manager works with a deprecation warning. + + Args: + state_manager: A state manager instance. + token: A token. + """ + from unittest.mock import patch + + import reflex_base.utils.console as _base_console + + from reflex.istate.manager import token as _token_mod + + console = _token_mod.console + + from reflex.state import State + + legacy_token = f"{token}_{OnLoadState.get_full_name()}" + + def _clear_dedupe(): + _base_console._EMITTED_DEPRECATION_WARNINGS -= { + k + for k in _base_console._EMITTED_DEPRECATION_WARNINGS + if "Passing a string to modify_state" in k + } + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # modify_state should accept a legacy string token and emit a deprecation warning. + async with state_manager.modify_state(legacy_token) as state: + assert isinstance(state, State) + # The substate targeted by the token should be prepopulated. + assert OnLoadState.get_name() in state.substates + mock_deprecate.assert_called() + assert ( + mock_deprecate.call_args.kwargs["feature_name"] + == "Passing a string to modify_state" + ) + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # get_state should also accept a legacy string token. 
+ retrieved = await state_manager.get_state(legacy_token) + assert isinstance(retrieved, State) + assert OnLoadState.get_name() in retrieved.substates + mock_deprecate.assert_called() + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # set_state should also accept a legacy string token. + await state_manager.set_state(legacy_token, retrieved) + mock_deprecate.assert_called() + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + final = await state_manager.get_state(legacy_token) + assert isinstance(final, State) + assert OnLoadState.get_name() in final.substates + mock_deprecate.assert_called() @pytest_asyncio.fixture(loop_scope="function", scope="function") @@ -1820,11 +1878,11 @@ async def state_manager_redis() -> AsyncGenerator[StateManager, None]: Yields: A state manager instance """ - state_manager = StateManager.create(TestState) + state_manager = StateManager.create() if not isinstance(state_manager, StateManagerRedis): # Create a mocked redis client instead of skipping. - state_manager = StateManagerRedis(state=TestState, redis=mock_redis()) + state_manager = StateManagerRedis(redis=mock_redis()) yield state_manager @@ -1842,12 +1900,14 @@ def substate_token_redis(state_manager_redis, token): Returns: Token concatenated with the state_manager's state full_name. """ - return _substate_key(token, state_manager_redis.state) + return BaseStateToken(ident=token, cls=TestState) @pytest.mark.asyncio async def test_state_manager_lock_expire( - state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, + token: str, + substate_token_redis: BaseStateToken, ): """Test that the state manager lock expires and raises exception exiting context. 
@@ -1858,6 +1918,7 @@ async def test_state_manager_lock_expire( """ state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + state_manager_redis.oplock_hold_time_ms = LOCK_EXPIRATION // 2 loop_exception = None @@ -1892,7 +1953,9 @@ def loop_exception_handler(loop, context): @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, + token: str, + substate_token_redis: BaseStateToken, ): """Test that the state manager lock expires and queued waiters proceed. @@ -1906,6 +1969,7 @@ async def test_state_manager_lock_expire_contend( state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + state_manager_redis.oplock_hold_time_ms = LOCK_EXPIRATION // 2 loop_exception = None @@ -1968,7 +2032,7 @@ async def _coro_waiter(): async def test_state_manager_lock_warning_threshold_contend( state_manager_redis: StateManagerRedis, token: str, - substate_token_redis: str, + substate_token_redis: BaseStateToken, mocker: MockerFixture, ): """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. @@ -2116,14 +2180,20 @@ def from_dict(cls, data: dict) -> ModelDC: @pytest.mark.asyncio async def test_state_proxy( - grandchild_state: GrandchildState, mock_app: rx.App, token: str + grandchild_state: GrandchildState, + token: str, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], + attached_mock_event_context: EventContext, ): """Test that the state proxy works. Args: grandchild_state: A grandchild state. - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_base_state_event_processor: The event processor attached for this test. 
+ emitted_deltas: A list to capture emitted deltas. + attached_mock_event_context: The event context attached for this test. """ child_state = grandchild_state.parent_state assert child_state is not None @@ -2135,30 +2205,26 @@ async def test_state_proxy( "sid": "test_sid", }) grandchild_state.router = router_data - namespace = mock_app.event_namespace - assert namespace is not None - namespace.sid_to_token[router_data.session.session_id] = token - namespace._token_manager.instance_id = "mock" - namespace._token_manager.token_to_socket[token] = SocketRecord( - instance_id="mock", sid=router_data.session.session_id - ) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): - mock_app.state_manager.states[parent_state.router.session.client_token] = ( - parent_state - ) - elif isinstance(mock_app.state_manager, StateManagerRedis): + state_manager = attached_mock_event_context.state_manager + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): + state_manager.states[parent_state.router.session.client_token] = parent_state + elif isinstance(state_manager, StateManagerRedis): pickle_state = parent_state._serialize() if pickle_state: - await mock_app.state_manager.redis.set( - _substate_key(parent_state.router.session.client_token, parent_state), + await state_manager.redis.set( + str( + BaseStateToken( + ident=parent_state.router.session.client_token, + cls=type(parent_state), + ) + ), pickle_state, - ex=mock_app.state_manager.token_expiration, + ex=state_manager.token_expiration, ) sp = StateProxy(grandchild_state) assert sp.__wrapped__ == grandchild_state assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split(".")) - assert sp._self_app is mock_app assert not sp._self_mutable assert sp._self_actx is None @@ -2189,7 +2255,7 @@ async def test_state_proxy( async with sp: assert sp._self_actx is not None assert sp._self_mutable # proxy is mutable inside context - if isinstance(mock_app.state_manager, 
(StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert sp.__wrapped__ is grandchild_state else: @@ -2201,13 +2267,16 @@ async def test_state_proxy( assert sp.value2 == "42" if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() + await state_manager.close() # Get the state from the state manager directly and check that the value is updated - gotten_state = await mock_app.state_manager.get_state( - _substate_key(grandchild_state.router.session.client_token, grandchild_state) + gotten_state = await state_manager.get_state( + BaseStateToken( + ident=grandchild_state.router.session.client_token, + cls=type(grandchild_state), + ) ) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert gotten_state is parent_state else: @@ -2218,23 +2287,21 @@ async def test_state_proxy( assert gotten_grandchild_state.value2 == "42" # ensure state update was emitted - assert mock_app.event_namespace is not None - mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportAttributeAccessIssue] - mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportAttributeAccessIssue] - assert mcall.args[0] == str(SocketEvent.EVENT) - assert mcall.args[1] == StateUpdate( - delta={ - TestState.get_full_name(): {"router" + FIELD_MARKER: router_data}, - grandchild_state.get_full_name(): { - "value2" + FIELD_MARKER: "42", - }, - GrandchildState3.get_full_name(): { - "computed" + FIELD_MARKER: "", + await attached_mock_base_state_event_processor.join(timeout=1) + assert emitted_deltas == [ + ( + token, + { + TestState.get_full_name(): {"router" + FIELD_MARKER: router_data}, + grandchild_state.get_full_name(): { + "value2" + FIELD_MARKER: "42", + }, + 
GrandchildState3.get_full_name(): { + "computed" + FIELD_MARKER: "", + }, }, - }, - final=None, - ) - assert mcall.kwargs["to"] == grandchild_state.router.session.session_id + ) + ] class BackgroundTaskState(BaseState): @@ -2244,10 +2311,6 @@ class BackgroundTaskState(BaseState): dict_list: dict[str, list[int]] = {"foo": [1, 2, 3]} dc: ModelDC = ModelDC() - def __init__(self, **kwargs): # noqa: D107 - super().__init__(**kwargs) - self.router_data = {"simulate": "hydrate"} - @rx.var(cache=False) def computed_order(self) -> list[str]: """Get the order as a computed var. @@ -2344,79 +2407,44 @@ async def bad_chain2(self): @pytest.mark.asyncio -async def test_background_task_no_block(mock_app: rx.App, token: str): +async def test_background_task_no_block( + mock_app: rx.App, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, + state_manager: StateManager, +): """Test that a background task does not block other events. Args: mock_app: An app that will be returned by `get_app()` token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. + state_manager: A state manager instance. 
""" - router_data = {"query": {}, "token": token} - sid = "test_sid" - namespace = mock_app.event_namespace - assert namespace is not None - namespace.sid_to_token[sid] = token - namespace._token_manager.instance_id = "mock" - namespace._token_manager.token_to_socket[token] = SocketRecord( - instance_id="mock", sid=sid - ) - mock_app.state_manager.state = mock_app._state = BackgroundTaskState - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_full_name()}.background_task", - router_data=router_data, - payload={}, - ), - sid=sid, - headers={}, - client_ip="", - ): - # background task returns empty update immediately - assert update == StateUpdate() - - # wait for the coroutine to start - await asyncio.sleep(0.5 if CI else 0.1) - assert len(mock_app._background_tasks) == 1 - - # Process another normal event - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_full_name()}.other", - router_data=router_data, - payload={}, - ), - sid=sid, - headers={}, - client_ip="", - ): - # other task returns delta - assert update == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: [ - "background_task:start", - "other", - ], - "computed_order" + FIELD_MARKER: [ - "background_task:start", - "other", - ], - } - }, + async with mock_base_state_event_processor as processor: + # Start background task + await processor.enqueue( + token, + Event( + name=f"{BackgroundTaskState.get_full_name()}.background_task", + payload={}, + ), + ) + # Wait for the background task coroutine to start + await asyncio.sleep(0.5 if CI else 0.1) + + # Process another normal event while background task is polling + await processor.enqueue( + token, + Event( + name=f"{BackgroundTaskState.get_full_name()}.other", + payload={}, + ), ) - # Explicit wait for background tasks - for task in tuple(mock_app._background_tasks): - await task - assert not 
mock_app._background_tasks - - if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() - + # After processor context exits, all tasks including background are done. exp_order = [ "background_task:start", "other", @@ -2425,98 +2453,45 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] - background_task_state = await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + + background_task_state = await state_manager.get_state( + BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) assert background_task_state.order == exp_order - assert mock_app.event_namespace is not None - emit_mock = mock_app.event_namespace.emit - - first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportAttributeAccessIssue] - assert ( - first_ws_message.delta[BackgroundTaskState.get_full_name()].pop( - "router" + FIELD_MARKER - ) - is not None - ) - assert first_ws_message == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: ["background_task:start"], - "computed_order" + FIELD_MARKER: ["background_task:start"], - } - }, - events=[], - final=None, - ) - for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportAttributeAccessIssue] - assert call.args[1] == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "computed_order" + FIELD_MARKER: ["background_task:start"], - } - }, - events=[], - final=None, - ) - assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportAttributeAccessIssue] - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: exp_order, - "computed_order" + FIELD_MARKER: exp_order, - "dict_list" + FIELD_MARKER: {}, - } - }, - events=[], - final=None, - ) - assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore 
[reportAttributeAccessIssue] - delta={ - BackgroundTaskState.get_full_name(): { - "computed_order" + FIELD_MARKER: exp_order, - }, - }, - events=[], - final=None, - ) @pytest.mark.asyncio -async def test_background_task_reset(mock_app: rx.App, token: str): +async def test_background_task_reset( + mock_app: rx.App, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + state_manager: StateManager, +): """Test that a background task calling reset is protected by the state proxy. Args: mock_app: An app that will be returned by `get_app()` token: A token. + mock_base_state_event_processor: The event processor. + state_manager: A state manager instance. """ - router_data = {"query": {}} - mock_app.state_manager.state = mock_app._state = BackgroundTaskState - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_name()}.background_task_reset", - router_data=router_data, - payload={}, - ), - sid="", - headers={}, - client_ip="", - ): - # background task returns empty update immediately - assert update == StateUpdate() - - # Explicit wait for background tasks - for task in tuple(mock_app._background_tasks): - await task - assert not mock_app._background_tasks + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + name=f"{BackgroundTaskState.get_full_name()}.background_task_reset", + payload={}, + ), + ) if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() + await state_manager.close() - background_task_state = await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) + background_task_state = await state_manager.get_state( + BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) assert background_task_state.order == ["reset"] @@ -2803,6 +2778,7 @@ def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable) assert 
not isinstance(var_copy, MutableProxy) +@pytest.mark.usefixtures("forked_registration_context") def test_duplicate_substate_class(mocker: MockerFixture): # Neuter pytest escape hatch, because we want to test duplicate detection. mocker.patch("reflex.state.is_testing_env", return_value=False) @@ -2998,7 +2974,7 @@ class BaseFieldSetterState(BaseState): assert "c2" in bfss.dirty_vars -def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> dict[str, Any]: +def exp_is_hydrated(state: type[BaseState], is_hydrated: bool = True) -> dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. Args: @@ -3030,18 +3006,21 @@ class OnLoadState2(State): num: int = 0 name: str + @rx.event def test_handler(self): """Test handler that calls another handler. - Returns: - Chain of EventHandlers + Yields: + EventHandler to change name. """ self.num += 1 - return self.change_name + yield type(self).change_name + yield type(self).change_name("other") - def change_name(self): + @rx.event + def change_name(self, name: str = "default"): """Test handler to change name.""" - self.name = "random" + self.name = name class OnLoadState3(State): @@ -3058,13 +3037,40 @@ async def test_handler(self): @pytest.mark.parametrize( ("test_state", "expected"), [ - (OnLoadState, {"on_load_state": {"num": 1}}), - (OnLoadState2, {"on_load_state2": {"num": 1}}), - (OnLoadState3, {"on_load_state3": {"num": 1}}), + ( + OnLoadState, + [ + {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + ], + ), + ( + OnLoadState2, + [ + {OnLoadState2.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "default"}}, + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "other"}}, + ], + ), + ( + OnLoadState3, + [ + {OnLoadState3.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + ], + ), ], ) async def test_preprocess( - app_module_mock, 
token, test_state, expected, mocker: MockerFixture + app_module_mock, + token, + test_state, + expected, + mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, ): """Test that a state hydrate event is processed correctly. @@ -3074,12 +3080,13 @@ async def test_preprocess( test_state: State to process event. expected: Expected delta. mocker: pytest mock object. + mock_root_event_context: The mock root event context. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ OnLoadInternalState._app_ref = None - mocker.patch( - "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} - ) app = app_module_mock.app = App(_state=State) + app._state_manager = mock_root_event_context.state_manager def index(): return "hello" @@ -3087,40 +3094,49 @@ def index(): app.add_page(index, on_load=test_state.test_handler) app._compile_page("index") - async with app.state_manager.modify_state(_substate_key(token, State)) as state: - state.router_data = {"simulate": "hydrate"} + on_load_internal_name = format.format_event_handler( + OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] + ) - updates = [] - async for update in rx.app.process( - app=app, - event=Event( - token=token, - name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}", - router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, - ), - sid="sid", - headers={}, - client_ip="", + async with mock_base_state_event_processor as processor: + on_load_future = await processor.enqueue( + token, + Event( + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), + ) + await on_load_future.wait_all() + + # The processor chains all events: on_load_internal sets is_hydrated=False, + # then the on_load handler runs, then set_is_hydrated(True) runs. 
+ # First delta: router + is_hydrated=False + assert len(emitted_deltas) == 1 + len(expected) + first_token, first_delta = emitted_deltas[0] + assert first_token == token + assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None + assert first_delta == exp_is_hydrated(State, False) + + # Find the deltas containing the test handler's state change + for (delta_token, actual_delta), expected_delta in zip( + emitted_deltas[1:], expected, strict=True ): - assert isinstance(update, StateUpdate) - updates.append(update) - assert len(updates) == 1 - assert updates[0].delta[State.get_name()].pop("router" + FIELD_MARKER) is not None - assert updates[0].delta == exp_is_hydrated(state, False) - - events = updates[0].events - assert len(events) == 2 - async for update in state._process(events[0]): - assert update.delta == {test_state.get_full_name(): {"num" + FIELD_MARKER: 1}} - async for update in state._process(events[1]): - assert update.delta == exp_is_hydrated(state) - - await app.state_manager.close() + assert delta_token == token + assert actual_delta == expected_delta @pytest.mark.asyncio async def test_preprocess_multiple_load_events( - app_module_mock, token, mocker: MockerFixture + app_module_mock, + token, + mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, ): """Test that a state hydrate event for multiple on-load events is processed correctly. @@ -3128,67 +3144,81 @@ async def test_preprocess_multiple_load_events( app_module_mock: The app module that will be returned by get_app(). token: A token. mocker: pytest mock object. + mock_root_event_context: The mock root event context. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. 
""" OnLoadInternalState._app_ref = None - mocker.patch( - "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} - ) app = app_module_mock.app = App(_state=State) + app._state_manager = mock_root_event_context.state_manager def index(): return "hello" app.add_page(index, on_load=[OnLoadState.test_handler, OnLoadState.test_handler]) app._compile_page("index") - async with app.state_manager.modify_state(_substate_key(token, State)) as state: - state.router_data = {"simulate": "hydrate"} - - updates = [] - async for update in rx.app.process( - app=app, - event=Event( - token=token, - name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}", - router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, - ), - sid="sid", - headers={}, - client_ip="", - ): - assert isinstance(update, StateUpdate) - updates.append(update) - assert len(updates) == 1 - assert updates[0].delta[State.get_name()].pop("router" + FIELD_MARKER) is not None - assert updates[0].delta == exp_is_hydrated(state, False) - events = updates[0].events - assert len(events) == 3 - async for update in state._process(events[0]): - assert update.delta == {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 1}} - async for update in state._process(events[1]): - assert update.delta == {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 2}} - async for update in state._process(events[2]): - assert update.delta == exp_is_hydrated(state) + on_load_internal_name = format.format_event_handler( + OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] + ) - await app.state_manager.close() + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), + ) + await processor.join() + + # First delta: router + is_hydrated=False + assert len(emitted_deltas) >= 2 + first_delta = emitted_deltas[0][1] + 
assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None + assert first_delta == exp_is_hydrated(State, False) + + # Find deltas containing the test handler's state change (num incremented twice) + handler_deltas = [ + d + for _, d in emitted_deltas + if OnLoadState.get_full_name() in d + and "num" + FIELD_MARKER in d[OnLoadState.get_full_name()] + ] + assert len(handler_deltas) == 2 + assert handler_deltas[0][OnLoadState.get_full_name()]["num" + FIELD_MARKER] == 1 + assert handler_deltas[1][OnLoadState.get_full_name()]["num" + FIELD_MARKER] == 2 + + # Find the delta that sets is_hydrated back to True + hydrated_deltas = [ + d + for _, d in emitted_deltas + if State.get_full_name() in d + and d[State.get_full_name()].get(CompileVars.IS_HYDRATED + FIELD_MARKER) is True + ] + assert len(hydrated_deltas) == 1 @pytest.mark.asyncio -async def test_get_state(mock_app: rx.App, token: str): +async def test_get_state(token: str, attached_mock_event_context: EventContext): """Test that a get_state populates the top level state and delta calculation is correct. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context with a state manager that has a TestState instance corresponding to the token. """ - mock_app.state_manager.state = mock_app._state = TestState + state_manager = attached_mock_event_context.state_manager # Get instance of ChildState2. 
- test_state = await mock_app.state_manager.get_state( - _substate_key(token, ChildState2) + test_state = await state_manager.get_state( + BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(test_state, TestState) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # All substates are available assert tuple(sorted(test_state.substates)) == ( ChildState.get_name(), @@ -3248,11 +3278,11 @@ async def test_get_state(mock_app: rx.App, token: str): } # Get a fresh instance - new_test_state = await mock_app.state_manager.get_state( - _substate_key(token, ChildState2) + new_test_state = await state_manager.get_state( + BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(new_test_state, TestState) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # In memory, it's the same instance assert new_test_state is test_state test_state._clean() @@ -3289,7 +3319,9 @@ async def test_get_state(mock_app: rx.App, token: str): @pytest.mark.asyncio -async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): +async def test_get_state_from_sibling_not_cached( + token: str, attached_mock_event_context: EventContext +): """A test simulating update_vars_internal when setting cookies with computed vars. In that case, a sibling state, UpdateVarsInternalState handles the fetching @@ -3302,8 +3334,8 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): Explicit regression test for https://github.com/reflex-dev/reflex/issues/2851. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context with a state manager that has a TestState instance corresponding to the token. 
""" class Parent(BaseState): @@ -3342,14 +3374,14 @@ class GreatGrandchild3(Grandchild3): has a computed var. """ - mock_app.state_manager.state = mock_app._state = Parent + state_manager = attached_mock_event_context.state_manager # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + root = await state_manager.get_state(BaseStateToken(ident=token, cls=Child)) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 - if isinstance(mock_app.state_manager, StateManagerRedis): + if isinstance(state_manager, StateManagerRedis): # When redis is used, only states with computed vars are pre-fetched. assert Child2.get_name() not in root.substates assert Child3.get_name() in root.substates # (due to @rx.var) @@ -3428,8 +3460,7 @@ def foo(self) -> str: ] # Get state from state manager. - state_manager.state = State - rx_state = await state_manager.get_state(_substate_key(token, State)) + rx_state = await state_manager.get_state(BaseStateToken(ident=token, cls=State)) assert RouterVarParentState.get_name() in rx_state.substates parent_state = rx_state.substates[RouterVarParentState.get_name()] assert RouterVarDepState.get_name() in parent_state.substates @@ -3445,29 +3476,46 @@ def foo(self) -> str: @pytest.mark.asyncio -async def test_setvar(mock_app: rx.App, token: str): +async def test_setvar( + state_manager: StateManager, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, +): """Test that setvar works correctly. Args: - mock_app: An app that will be returned by `get_app()` + state_manager: A state manager instance. token: A token. + mock_base_state_event_processor: The event processor. 
""" - state = await mock_app.state_manager.get_state(_substate_key(token, TestState)) - assert isinstance(state, TestState) - # Set Var in same state (with Var type casting) - for event in rx.event.fix_events( - [TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token - ): - async for update in state._process(event): - print(update) + events = Event.from_event_type([ + TestState.setvar("num1", 42), + TestState.setvar("num2", "4.2"), + ]) + async with mock_base_state_event_processor as processor: + for fut in asyncio.as_completed(await processor.enqueue_many(token, *events)): + await fut + await processor.join(1) + + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) + assert isinstance(state, TestState) assert state.num1 == 42 assert math.isclose(state.num2, 4.2) # Set Var in parent state - for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token): - async for update in state._process(event): - print(update) + events = Event.from_event_type([GrandchildState.setvar("array", [43])]) + async with mock_base_state_event_processor as processor: + await (await processor.enqueue(token, events[0])) + + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) + assert isinstance(state, TestState) assert state.array == [43] # Cannot setvar for non-existent var @@ -3552,9 +3600,8 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_ with chdir(proj_root): # reload config for each parameter to avoid stale values reflex_base.config.get_config(reload=True) - from reflex.state import State - state_manager = StateManagerRedis(state=State, redis=mock_redis()) + state_manager = StateManagerRedis(redis=mock_redis()) assert state_manager.lock_expiration == expected_values[0] # pyright: ignore 
[reportAttributeAccessIssue] assert state_manager.token_expiration == expected_values[1] # pyright: ignore [reportAttributeAccessIssue] assert state_manager.lock_warning_threshold == expected_values[2] # pyright: ignore [reportAttributeAccessIssue] @@ -3589,10 +3636,9 @@ def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( with chdir(proj_root): # reload config for each parameter to avoid stale values reflex_base.config.get_config(reload=True) - from reflex.state import State with pytest.raises(InvalidLockWarningThresholdError): - StateManagerRedis(state=State, redis=mock_redis()) + StateManagerRedis(redis=mock_redis()) del sys.modules[constants.Config.MODULE] @@ -3616,9 +3662,7 @@ def test_state_manager_create_respects_explicit_memory_mode_with_redis_url( with chdir(proj_root): reflex_base.config.get_config(reload=True) monkeypatch.setattr(prerequisites, "get_redis", mock_redis) - from reflex.state import State - - state_manager = StateManager.create(state=State) + state_manager = StateManager.create() assert isinstance(state_manager, StateManagerMemory) del sys.modules[constants.Config.MODULE] @@ -3840,8 +3884,10 @@ class State(Root): class Child(State): foo: str = "bar" - dsm = StateManagerDisk(state=Root) - async with dsm.modify_state(token) as root: + bs_token = BaseStateToken(ident=token, cls=Root) + + dsm = StateManagerDisk() + async with dsm.modify_state(bs_token) as root: s = await root.get_state(State) s.num += 1 c = await root.get_state(Child) @@ -3849,8 +3895,8 @@ class Child(State): assert not c._get_was_touched() await dsm.close() - dsm2 = StateManagerDisk(state=Root) - root = await dsm2.get_state(token) + dsm2 = StateManagerDisk() + root = await dsm2.get_state(bs_token) s = await root.get_state(State) assert s.num == 43 c = await root.get_state(Child) @@ -4152,7 +4198,6 @@ def py_unresolvable(self, u: Unresolvable): # noqa: D102, F821 # pyright: ignor @pytest.mark.asyncio -@pytest.mark.usefixtures("mock_app_simple") 
@pytest.mark.parametrize( ("handler", "payload"), [ @@ -4171,22 +4216,38 @@ def py_unresolvable(self, u: Unresolvable): # noqa: D102, F821 # pyright: ignor (UpcastState.py_unresolvable, {"u": ["foo"]}), ], ) -async def test_upcast_event_handler_arg(handler, payload): +async def test_upcast_event_handler_arg( + handler, + payload, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test that upcast event handler args work correctly. Args: handler: The handler to test. payload: The payload to test. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ - state = UpcastState() - async for update in state._process_event(handler, state, payload): - assert update.delta == { - UpcastState.get_full_name(): {"passed" + FIELD_MARKER: True} - } + event = Event( + name=format.format_event_handler(handler), + payload=payload, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert len(emitted_deltas) == 1 + assert emitted_deltas[0][1] == { + UpcastState.get_full_name(): {"passed" + FIELD_MARKER: True} + } @pytest.mark.asyncio -async def test_get_var_value(state_manager: StateManager, substate_token: str): +async def test_get_var_value( + state_manager: StateManager, substate_token: BaseStateToken +): """Test that get_var_value works correctly. Args: @@ -4221,12 +4282,14 @@ async def test_get_var_value(state_manager: StateManager, substate_token: str): @pytest.mark.asyncio -async def test_async_computed_var_get_state(mock_app: rx.App, token: str): +async def test_async_computed_var_get_state( + token: str, attached_mock_event_context: EventContext +): """A test where an async computed var depends on a var in another state. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context that will be attached to the app's state manager. 
""" class Parent(BaseState): @@ -4259,14 +4322,14 @@ async def v(self) -> int: child3 = await self.get_state(Child3) return child3.child3_var + p.parent_var - mock_app.state_manager.state = mock_app._state = Parent + state_manager = attached_mock_event_context.state_manager # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + root = await state_manager.get_state(BaseStateToken(ident=token, cls=Child)) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 - if isinstance(mock_app.state_manager, StateManagerRedis): + if isinstance(state_manager, StateManagerRedis): # When redis is used, only states with uncached computed vars are pre-fetched. assert Child2.get_name() not in root.substates assert Child3.get_name() not in root.substates @@ -4341,9 +4404,11 @@ class OtherState(rx.State): data: list[dict[str, Any]] = [{"foo": "bar"}] - mock_app.state_manager.state = mock_app._state = rx.State + mock_app._state = rx.State comp = Table.create(data=OtherState.data) - state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + state = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=OtherState) + ) other_state = await state.get_state(OtherState) assert comp.State is not None # The state should have been pre-cached from the dependency. 
@@ -4376,7 +4441,9 @@ class SecondCvState(CvMixin, rx.State): @pytest.mark.asyncio -async def test_add_dependency_get_state_regression(mock_app: rx.App, token: str): +async def test_add_dependency_get_state_regression( + token: str, attached_mock_event_context: EventContext, mock_app: rx.App +): """Ensure that a state class can be fetched separately when it's is explicit dep.""" class DataState(rx.State): @@ -4401,8 +4468,9 @@ class OtherState(rx.State): async def fetch_data_state(self) -> None: print(await self.get_state(DataState)) - mock_app.state_manager.state = mock_app._state = rx.State - state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + state = await attached_mock_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=OtherState) + ) other_state = await state.get_state(OtherState) await other_state.fetch_data_state() # Should not raise exception. @@ -4414,11 +4482,14 @@ class MutableProxyState(BaseState): @pytest.mark.asyncio -async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: +async def test_rebind_mutable_proxy( + token: str, attached_mock_event_context: EventContext +) -> None: """Test that previously bound MutableProxy instances can be rebound correctly.""" - mock_app.state_manager.state = mock_app._state = MutableProxyState - async with mock_app.state_manager.modify_state( - _substate_key(token, MutableProxyState) + state_manager = attached_mock_event_context.state_manager + + async with state_manager.modify_state( + BaseStateToken(ident=token, cls=MutableProxyState) ) as state: state.router = RouterData.from_router_data({ "query": {}, @@ -4442,7 +4513,7 @@ async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: assert not isinstance(state_proxy.__wrapped__.data["a"], ImmutableMutableProxy) # Flush any oplock. 
- await mock_app.state_manager.close() + await state_manager.close() new_state_proxy = StateProxy(state) assert state_proxy is not new_state_proxy @@ -4453,8 +4524,8 @@ async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: async with state_proxy: state_proxy.data["a"].append(3) - async with mock_app.state_manager.modify_state( - _substate_key(token, MutableProxyState) + async with state_manager.modify_state( + BaseStateToken(ident=token, cls=MutableProxyState) ) as state: assert isinstance(state, MutableProxyState) assert state.data["a"] == [2, 3] diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index 74e450133b3..2e72d83fc8b 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -9,7 +9,8 @@ import reflex as rx from reflex.istate.manager import StateManager from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState class Root(BaseState): @@ -372,7 +373,9 @@ async def test_get_state_tree( exp_root_substates: The expected substates of the root state. exp_root_dict_keys: The expected keys of the root state dict. """ - state = await state_manager_redis.get_state(_substate_key(token, substate_cls)) + state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=substate_cls) + ) assert isinstance(state, Root) assert sorted(state.substates) == sorted(exp_root_substates)