diff --git a/reflex/app.py b/reflex/app.py index bdbb90bfe40..5d373e0418b 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -16,6 +16,7 @@ import traceback import urllib.parse from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence +from contextvars import Token from datetime import datetime from itertools import chain from pathlib import Path @@ -31,6 +32,7 @@ evaluate_style_namespaces, ) from reflex_base.config import get_config +from reflex_base.context.base import BaseContext from reflex_base.environment import ExecutorType, environment from reflex_base.event import ( _EVENT_FIELDS, @@ -40,6 +42,7 @@ IndividualEventType, noop, ) +from reflex_base.event.context import EventContext from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor from reflex_base.registry import RegistrationContext from reflex_base.utils import console @@ -574,8 +577,65 @@ async def modified_send(message: Message): # Ensure the event processor starts and stops with the server. self.register_lifespan_task(self._setup_event_processor) - def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: - """Ensure the RegistrationContext is attached to the ASGI app. + def _set_contexts_internal(self) -> dict[type[BaseContext], Token]: + """Set Reflex contexts if not already present, returning reset tokens. + + Returns: + A dict mapping context class to the contextvars Token for each + context that was set. Empty if all contexts were already present. + """ + tokens: dict[type[BaseContext], Token] = {} + + if self._registration_context is not None: + try: + RegistrationContext.get() + except LookupError: + tokens[RegistrationContext] = RegistrationContext.set( + self._registration_context + ) + + if ( + self._event_processor is not None + and self._event_processor._root_context is not None + ): + try: + EventContext.get() + except LookupError: + tokens[EventContext] = EventContext.set( + self._event_processor._root_context + ) + + return tokens + + def set_contexts(self) -> contextlib.AbstractContextManager: + """Set Reflex contexts needed for state and event processing. + + Pushes RegistrationContext and EventContext into the current + contextvars scope, but only if they are not already set. + + Can be used as a context manager:: + + with app.set_contexts(): + async with app.modify_state(token) as state: + ... + + Returns: + A context manager that resets any contexts that were set on exit. + """ + tokens = self._set_contexts_internal() + if not tokens: + return contextlib.nullcontext() + stack = contextlib.ExitStack() + for ctx_cls, tok in tokens.items(): + stack.callback(ctx_cls.reset, tok) + return stack + + def _context_middleware(self, app: ASGIApp) -> ASGIApp: + """Ensure Reflex contexts are attached for each ASGI request. + + Many ASGI servers start each request with a fresh contextvars scope, + so this middleware re-applies the RegistrationContext and EventContext + that are needed for Reflex state and event processing. Args: app: The ASGI app to attach the middleware to. @@ -584,14 +644,11 @@ def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: The ASGI app with the middleware attached. """ - async def registration_context_middleware( - scope: Scope, receive: Receive, send: Send - ): - if self._registration_context is not None: - RegistrationContext.set(self._registration_context) + async def context_middleware(scope: Scope, receive: Receive, send: Send): + self._set_contexts_internal() await app(scope, receive, send) - return registration_context_middleware + return context_middleware @contextlib.asynccontextmanager async def _setup_event_processor(self) -> AsyncIterator[None]: @@ -669,10 +726,10 @@ def __call__(self) -> ASGIApp: asgi_app = api_transformer(asgi_app) top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) - # Make sure the RegistrationContext is attached. + # Make sure Reflex contexts are attached for each request. top_asgi_app.mount( "", - self._registration_context_middleware(asgi_app), + self._context_middleware(asgi_app), ) App._add_cors(top_asgi_app) return top_asgi_app @@ -1607,20 +1664,22 @@ async def modify_state( if isinstance(token, str): token = BaseStateToken.from_legacy_token(token, root_state=self._state) - # Get exclusive access to the state. - async with self.state_manager.modify_state_with_links( - token, previous_dirty_vars=previous_dirty_vars, **context - ) as state: - # No other event handler can modify the state while in this context. - yield state - delta = await state._get_resolved_delta() - state._clean() - if delta: - # When the frontend vars are modified emit the delta to the frontend. - await self.event_namespace.emit_update( - update=StateUpdate(delta=delta), - token=token.ident, - ) + # Ensure Reflex contexts are available (e.g. when called from an API route). + with self.set_contexts(): + # Get exclusive access to the state. + async with self.state_manager.modify_state_with_links( + token, previous_dirty_vars=previous_dirty_vars, **context + ) as state: + # No other event handler can modify the state while in this context. + yield state + delta = await state._get_resolved_delta() + state._clean() + if delta: + # When the frontend vars are modified emit the delta to the frontend. + await self.event_namespace.emit_update( + update=StateUpdate(delta=delta), + token=token.ident, + ) def _validate_exception_handlers(self): """Validate the custom event exception handlers for front- and backend. diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 30c3b5c5fee..31f28ec763a 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -173,6 +173,45 @@ def _rehydrate(self): State.set_is_hydrated(True), ] + async def _resolve_linked_state( + self, state_cls: type["BaseState"], linked_token: str + ) -> "BaseState": + """Load and patch a linked state that was not pre-loaded in the tree. + + Called by State._get_state_from_redis when a state in + _reflex_internal_links is not yet in the cache. This loads the + private copy into the tree first, then patches the linked version + on top of it via _internal_patch_linked_state. + + Args: + state_cls: The shared state class to resolve. + linked_token: The shared token the state is linked to. + + Returns: + The linked state instance, patched into the current tree. + """ + root_state = self._get_root_state() + + # Load the private copy into the tree so _internal_patch_linked_state + # has an original to swap out (needed for unlink / restore). + await BaseState._get_state_from_redis(root_state, state_cls) + + # Retrieve the private instance that was just attached to the tree. + original_state = root_state._get_state_from_cache(state_cls) + + # If we are inside _modify_linked_states, we can properly patch the + # linked state using the exit stack for lock management. + if ( + self._exit_stack is not None + and self._held_locks is not None + and isinstance(original_state, SharedStateBaseInternal) + ): + return await original_state._internal_patch_linked_state(linked_token) + + # Outside _modify_linked_states context - return the private copy as + # a safe fallback (this path should rarely be hit in practice). + return original_state + async def _link_to(self, token: str) -> Self: """Link this shared state to a token. @@ -280,6 +319,18 @@ async def _internal_patch_linked_state( BaseStateToken(ident=token, cls=type(self)) ) ) + # Set client_token on the linked root so that subsequent get_state + # calls (which fall through to _get_state_from_redis) use the + # linked token rather than an empty/default value. + if linked_root_state.router.session.client_token != token: + import dataclasses as dc + + linked_root_state.router = dc.replace( + linked_root_state.router, + session=dc.replace( + linked_root_state.router.session, client_token=token + ), + ) self._held_locks.setdefault(token, {}) else: linked_root_state = await get_state_manager().get_state( @@ -386,6 +437,13 @@ async def _modify_linked_states( for token in linked_state._linked_from if token != self.router.session.client_token ) + # When modifying a shared token directly (empty _reflex_internal_links), + # the held locks will be empty. Check SharedState substates for linked + # clients that need to be notified. + if not self._reflex_internal_links: + self._collect_shared_token_updates( + affected_tokens, current_dirty_vars + ) finally: self._exit_stack = None @@ -397,6 +455,32 @@ async def _modify_linked_states( state_type=type(self), ) + def _collect_shared_token_updates( + self, + affected_tokens: set[str], + current_dirty_vars: dict[str, set[str]], + ) -> None: + """Collect dirty vars and linked clients from SharedState substates. + + When a shared state is modified directly by its shared token (rather than + through a private client token), the held locks are empty so the normal + collection loop above finds nothing. This method checks the SharedState + substates directly for linked clients that need to be notified. + + Args: + affected_tokens: Set to update with client tokens that need notification. + current_dirty_vars: Dict to update with dirty var mappings per state. + """ + for substate in self.substates.values(): + if not isinstance(substate, SharedState) or not substate._linked_from: + continue + if substate._previous_dirty_vars: + current_dirty_vars[substate.get_full_name()] = set( + substate._previous_dirty_vars + ) + if substate._get_was_touched() or substate._previous_dirty_vars: + affected_tokens.update(substate._linked_from) + class SharedState(SharedStateBaseInternal, mixin=True): """Mixin for defining new shared states.""" diff --git a/reflex/state.py b/reflex/state.py index def7fcc382f..bbd78605c58 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2146,7 +2146,6 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: Returns: The instance of state_cls associated with this state's client_token. """ - state_instance = await super()._get_state_from_redis(state_cls) if ( self._reflex_internal_links and ( @@ -2155,15 +2154,12 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: ) ) is not None - and ( - internal_patch_linked_state := getattr( - state_instance, "_internal_patch_linked_state", None - ) - ) - is not None ): - return await internal_patch_linked_state(linked_token) - return state_instance + from reflex.istate.shared import SharedStateBaseInternal + + shared_base = await self.get_state(SharedStateBaseInternal) + return await shared_base._resolve_linked_state(state_cls, linked_token) # type: ignore[return-value] + return await super()._get_state_from_redis(state_cls) @event async def hydrate(self) -> None: diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index 4f3d5b14453..e3d234dc6e0 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -5,9 +5,12 @@ import uuid from collections.abc import Callable, Generator +import httpx import pytest +from reflex_base.config import get_config from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys +from selenium.webdriver.remote.webelement import WebElement from reflex.testing import AppHarness, WebDriver @@ -63,7 +66,18 @@ async def handle_submit(self, form_data: dict[str, Any]): if "token" in form_data: await self.link_to(form_data["token"]) + class SharedNotes(rx.SharedState): + """A second SharedState to test multi-SharedState propagation.""" + + note: str = "" + + @rx.event + async def on_load_link_default(self): + await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] + class PrivateState(rx.State): + fetched_note: str = "" + @rx.var async def greeting(self) -> str: ss = await self.get_state(SharedState) @@ -74,6 +88,12 @@ async def linked_to(self) -> str: ss = await self.get_state(SharedState) return ss._linked_to + @rx.event + async def fetch_shared_note(self): + """Fetch SharedNotes via get_state from an unrelated state handler.""" + sn = await self.get_state(SharedNotes) + self.fetched_note = sn.note + @rx.event(background=True) async def bump_counter_bg(self): for _ in range(5): @@ -140,10 +160,38 @@ def index() -> rx.Component: on_click=SharedState.link_to_and_increment, id="link-increment-button", ), + rx.text(SharedNotes.note, id="shared-note"), + rx.button( + "Fetch Note via get_state", + on_click=PrivateState.fetch_shared_note, + id="fetch-note-button", + ), + rx.text(PrivateState.fetched_note, id="fetched-note"), ) - app = rx.App() - app.add_page(index, route="/room/[room]", on_load=SharedState.on_load_link_default) + from fastapi import FastAPI + + api = FastAPI() + + @api.get("/api/set-counter/{shared_token}/{value}") + async def set_counter_api(shared_token: str, value: int): + """Modify shared state by its shared token from an API route.""" + from reflex.istate.manager.token import BaseStateToken + + async with app.modify_state( + BaseStateToken(ident=shared_token, cls=SharedState), + ) as state: + ss = await state.get_state(SharedState) + ss.counter = value + notes = await state.get_state(SharedNotes) + notes.note = f"counter set to {value}" + + app = rx.App(api_transformer=api) + app.add_page( + index, + route="/room/[room]", + on_load=[SharedState.on_load_link_default, SharedNotes.on_load_link_default], + ) app.add_page(index) @@ -386,3 +434,126 @@ def test_linked_state( # Link to a new state and increment the counter in the same event tab1.find_element(By.ID, "link-increment-button").click() assert linked_state.poll_for_content(counter_button_1, exp_not_equal="3") == "1" + + +def _open_linked_tab( + harness: AppHarness, + tab_factory: Callable[[], WebDriver], + shared_token: str, +) -> tuple[WebElement, WebElement]: + """Open a new tab linked to a shared token and return key elements. + + Args: + harness: The running AppHarness. + tab_factory: Factory to create WebDriver instances. + shared_token: The shared token to link to via on_load. + + Returns: + Tuple of (counter_button, note_element). + """ + tab = tab_factory() + tab.get(f"{harness.frontend_url}room/{shared_token}") + ss = utils.SessionStorage(tab) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + counter_button = AppHarness._poll_for( + lambda: tab.find_element(By.ID, "counter-button") + ) + assert counter_button + assert harness.poll_for_content(counter_button) == "0" + # Wait for SharedState.on_load_link_default to complete (linked-to shows the token). + linked_to = tab.find_element(By.ID, "linked-to") + assert harness.poll_for_content(linked_to) == shared_token + note = tab.find_element(By.ID, "shared-note") + assert note.text == "" + return counter_button, note + + +def test_modify_shared_state_by_shared_token( + linked_state: AppHarness, + tab_factory: Callable[[], WebDriver], +): + """Test that modifying shared state by shared token propagates to all linked clients. + + This exercises the use case of modifying shared state from an API route + where only the shared token is known (no private client token). + + Args: + linked_state: harness for LinkedStateApp. + tab_factory: factory to create WebDriver instances. + """ + assert linked_state.app_instance is not None + + shared_token = f"api-test-{uuid.uuid4()}" + + # Open two tabs linked to the same shared token via on_load + counter_button_1, note_1 = _open_linked_tab(linked_state, tab_factory, shared_token) + counter_button_2, note_2 = _open_linked_tab(linked_state, tab_factory, shared_token) + + # Modify both shared states by shared token via API route + api_url = f"{get_config().api_url}/api/set-counter/{shared_token}/42" + response = httpx.get(api_url) + assert response.status_code == 200 + + # Both tabs should see updates to both SharedState and SharedNotes + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42" + assert ( + linked_state.poll_for_content(note_1, exp_not_equal="") == "counter set to 42" + ) + assert ( + linked_state.poll_for_content(note_2, exp_not_equal="") == "counter set to 42" + ) + + # After the API-driven update, normal event handlers should still work + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="42") == "43" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="42") == "43" + + +def test_get_state_returns_linked_state( + linked_state: AppHarness, + tab_factory: Callable[[], WebDriver], +): + """Test that get_state from an unrelated handler returns the linked instance. + + When SharedNotes is linked to a shared token, calling + ``await self.get_state(SharedNotes)`` from PrivateState (an unrelated + state handler) should return the linked SharedNotes — not the private + copy. With the Redis state manager, SharedNotes may not be pre-loaded + in the tree, so ``get_state`` falls through to the redis fetch path. + + Args: + linked_state: harness for LinkedStateApp. + tab_factory: factory to create WebDriver instances. + """ + assert linked_state.app_instance is not None + + shared_token = f"get-state-test-{uuid.uuid4()}" + + # Open a tab linked to the shared token via on_load + tab = tab_factory() + tab.get(f"{linked_state.frontend_url}room/{shared_token}") + ss = utils.SessionStorage(tab) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + linked_to = AppHarness._poll_for(lambda: tab.find_element(By.ID, "linked-to")) + assert linked_to + # Wait for on_load to link SharedState (confirms event processing started). + assert linked_state.poll_for_content(linked_to) == shared_token + + # Modify SharedNotes.note on the shared token directly via the API. + api_url = f"{get_config().api_url}/api/set-counter/{shared_token}/99" + response = httpx.get(api_url) + assert response.status_code == 200 + + # Verify the linked note appears on the page (direct SharedNotes binding). + note = tab.find_element(By.ID, "shared-note") + assert linked_state.poll_for_content(note, exp_not_equal="") == "counter set to 99" + + # Now trigger PrivateState.fetch_shared_note — this calls get_state(SharedNotes) + # from an unrelated state handler. The returned instance must be the + # *linked* SharedNotes (with note="counter set to 99"), not the private copy. + tab.find_element(By.ID, "fetch-note-button").click() + fetched = tab.find_element(By.ID, "fetched-note") + assert ( + linked_state.poll_for_content(fetched, exp_not_equal="") == "counter set to 99" + ) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 2a3eb5556f4..c3e36a1df21 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import contextvars import functools import io import json @@ -2611,3 +2612,112 @@ class Sub(Base): ) else: assert app._event_namespace.emit_update.call_count == 0 + + +@pytest.fixture +def app_with_processor() -> App: + """Create an App with a mocked event processor that has a root context. + + Returns: + An App instance with a mock event processor and root context. + """ + app = App(_state=EmptyState) + root_context = EventContext( + token="", + state_manager=StateManagerMemory(), + enqueue_impl=AsyncMock(), + ) + processor = Mock() + processor._root_context = root_context + app._event_processor = processor + return app + + +def _run_isolated(fn): + """Run fn in a fresh empty context so all contextvars start unset. + + Args: + fn: A zero-argument callable to run. + """ + contextvars.Context().run(fn) + + +@pytest.mark.parametrize( + ("pre_set_registration", "pre_set_event"), + [ + pytest.param(False, False, id="neither_set"), + pytest.param(True, False, id="registration_already_set"), + pytest.param(False, True, id="event_already_set"), + pytest.param(True, True, id="both_already_set"), + ], +) +def test_set_contexts( + app_with_processor: App, + pre_set_registration: bool, + pre_set_event: bool, +): + """set_contexts sets absent contexts, preserves existing ones, and resets on exit.""" + + def _test(): + existing_reg = None + existing_ev = None + + if pre_set_registration: + existing_reg = RegistrationContext() + RegistrationContext.set(existing_reg) + if pre_set_event: + existing_ev = EventContext( + token="pre-existing", + state_manager=StateManagerMemory(), + enqueue_impl=AsyncMock(), + ) + EventContext.set(existing_ev) + + with app_with_processor.set_contexts(): + # Pre-existing contexts are preserved; absent ones are filled in. + if existing_reg is not None: + assert RegistrationContext.get() is existing_reg + else: + assert ( + RegistrationContext.get() + is app_with_processor._registration_context + ) + + if existing_ev is not None: + assert EventContext.get() is existing_ev + else: + assert app_with_processor._event_processor is not None + assert ( + EventContext.get() + is app_with_processor._event_processor._root_context + ) + + # After exit: pushed contexts are reset, pre-existing ones remain. + if existing_reg is not None: + assert RegistrationContext.get() is existing_reg + else: + with pytest.raises(LookupError): + RegistrationContext.get() + + if existing_ev is not None: + assert EventContext.get() is existing_ev + else: + with pytest.raises(LookupError): + EventContext.get() + + _run_isolated(_test) + + +def test_set_contexts_no_event_processor(): + """When event processor is None, EventContext should not be touched.""" + + def _test(): + app = App(_state=EmptyState) + assert app._event_processor is None + + with app.set_contexts(): + assert RegistrationContext.get() is app._registration_context + with pytest.raises(LookupError): + EventContext.get() + + _run_isolated(_test)