From 662dcd7f6111ec3f2220fa9f1e94d97e9412bdce Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 16 Apr 2026 01:15:30 -1000 Subject: [PATCH 1/4] ENG-9350: App.modify_state can directly modify SharedState tokens MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a SharedState is modified directly by its shared token (e.g. from an API route webhook), propagate dirty vars to all clients linked to that shared state — matching the behavior that already exists when modifying shared state through a private client token. - Add _collect_shared_token_updates on SharedStateBaseInternal to detect the shared-token case inside _modify_linked_states and propagate to linked clients via the existing _do_update_other_tokens mechanism - Add App.set_contexts / App._set_contexts_internal to centralize pushing RegistrationContext and EventContext into the current contextvars scope, replacing the old _registration_context_middleware - App.modify_state now calls set_contexts so out-of-band callers (API routes, webhooks) get the necessary contexts automatically - Integration test: API endpoint modifies two SharedState subclasses by shared token, asserts both propagate to two linked browser tabs, and verifies normal event handlers still work afterward - Unit tests for set_contexts covering all combinations of pre-existing / absent contexts, no-event-processor, and reset-on-exit behavior Closes #6335 --- reflex/app.py | 107 ++++++++++++++++++------ reflex/istate/shared.py | 33 ++++++++ tests/integration/test_linked_state.py | 109 +++++++++++++++++++++++- tests/units/test_app.py | 110 +++++++++++++++++++++++++ 4 files changed, 333 insertions(+), 26 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index 389d18f246a..82a32eb29f6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -16,6 +16,7 @@ import traceback import urllib.parse from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence +from contextvars import Token from datetime import datetime from itertools import chain from pathlib import Path @@ -31,6 +32,7 @@ evaluate_style_namespaces, ) from reflex_base.config import get_config +from reflex_base.context.base import BaseContext from reflex_base.environment import ExecutorType, environment from reflex_base.event import ( _EVENT_FIELDS, @@ -40,6 +42,7 @@ IndividualEventType, noop, ) +from reflex_base.event.context import EventContext from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor from reflex_base.registry import RegistrationContext from reflex_base.utils import console @@ -574,8 +577,65 @@ async def modified_send(message: Message): # Ensure the event processor starts and stops with the server. self.register_lifespan_task(self._setup_event_processor) - def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: - """Ensure the RegistrationContext is attached to the ASGI app. + def _set_contexts_internal(self) -> dict[type[BaseContext], Token]: + """Set Reflex contexts if not already present, returning reset tokens. + + Returns: + A dict mapping context class to the contextvars Token for each + context that was set. Empty if all contexts were already present. + """ + tokens: dict[type[BaseContext], Token] = {} + + if self._registration_context is not None: + try: + RegistrationContext.get() + except LookupError: + tokens[RegistrationContext] = RegistrationContext.set( + self._registration_context + ) + + if ( + self._event_processor is not None + and self._event_processor._root_context is not None + ): + try: + EventContext.get() + except LookupError: + tokens[EventContext] = EventContext.set( + self._event_processor._root_context + ) + + return tokens + + def set_contexts(self) -> contextlib.AbstractContextManager: + """Set Reflex contexts needed for state and event processing. + + Pushes RegistrationContext and EventContext into the current + contextvars scope, but only if they are not already set. + + Can be used as a context manager:: + + with app.set_contexts(): + async with app.modify_state(token) as state: + ... + + Returns: + A context manager that resets any contexts that were set on exit. + """ + tokens = self._set_contexts_internal() + if not tokens: + return contextlib.nullcontext() + stack = contextlib.ExitStack() + for ctx_cls, tok in tokens.items(): + stack.callback(ctx_cls.reset, tok) + return stack + + def _context_middleware(self, app: ASGIApp) -> ASGIApp: + """Ensure Reflex contexts are attached for each ASGI request. + + Many ASGI servers start each request with a fresh contextvars scope, + so this middleware re-applies the RegistrationContext and EventContext + that are needed for Reflex state and event processing. Args: app: The ASGI app to attach the middleware to. @@ -584,14 +644,11 @@ def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: The ASGI app with the middleware attached. """ - async def registration_context_middleware( - scope: Scope, receive: Receive, send: Send - ): - if self._registration_context is not None: - RegistrationContext.set(self._registration_context) + async def context_middleware(scope: Scope, receive: Receive, send: Send): + self._set_contexts_internal() await app(scope, receive, send) - return registration_context_middleware + return context_middleware @contextlib.asynccontextmanager async def _setup_event_processor(self) -> AsyncIterator[None]: @@ -669,10 +726,10 @@ def __call__(self) -> ASGIApp: asgi_app = api_transformer(asgi_app) top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) - # Make sure the RegistrationContext is attached. + # Make sure Reflex contexts are attached for each request. top_asgi_app.mount( "", - self._registration_context_middleware(asgi_app), + self._context_middleware(asgi_app), ) App._add_cors(top_asgi_app) return top_asgi_app @@ -1600,20 +1657,22 @@ async def modify_state( if isinstance(token, str): token = BaseStateToken.from_legacy_token(token, root_state=self._state) - # Get exclusive access to the state. - async with self.state_manager.modify_state_with_links( - token, previous_dirty_vars=previous_dirty_vars, **context - ) as state: - # No other event handler can modify the state while in this context. - yield state - delta = await state._get_resolved_delta() - state._clean() - if delta: - # When the frontend vars are modified emit the delta to the frontend. - await self.event_namespace.emit_update( - update=StateUpdate(delta=delta), - token=token.ident, - ) + # Ensure Reflex contexts are available (e.g. when called from an API route). + with self.set_contexts(): + # Get exclusive access to the state. + async with self.state_manager.modify_state_with_links( + token, previous_dirty_vars=previous_dirty_vars, **context + ) as state: + # No other event handler can modify the state while in this context. + yield state + delta = await state._get_resolved_delta() + state._clean() + if delta: + # When the frontend vars are modified emit the delta to the frontend. + await self.event_namespace.emit_update( + update=StateUpdate(delta=delta), + token=token.ident, + ) def _validate_exception_handlers(self): """Validate the custom event exception handlers for front- and backend. diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 30c3b5c5fee..58e1b46ff02 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -386,6 +386,13 @@ async def _modify_linked_states( for token in linked_state._linked_from if token != self.router.session.client_token ) + # When modifying a shared token directly (empty _reflex_internal_links), + # the held locks will be empty. Check SharedState substates for linked + # clients that need to be notified. + if not self._reflex_internal_links: + self._collect_shared_token_updates( + affected_tokens, current_dirty_vars + ) finally: self._exit_stack = None @@ -397,6 +404,32 @@ async def _modify_linked_states( state_type=type(self), ) + def _collect_shared_token_updates( + self, + affected_tokens: set[str], + current_dirty_vars: dict[str, set[str]], + ) -> None: + """Collect dirty vars and linked clients from SharedState substates. + + When a shared state is modified directly by its shared token (rather than + through a private client token), the held locks are empty so the normal + collection loop above finds nothing. This method checks the SharedState + substates directly for linked clients that need to be notified. + + Args: + affected_tokens: Set to update with client tokens that need notification. + current_dirty_vars: Dict to update with dirty var mappings per state. + """ + for substate in self.substates.values(): + if not isinstance(substate, SharedState) or not substate._linked_from: + continue + if substate._previous_dirty_vars: + current_dirty_vars[substate.get_full_name()] = set( + substate._previous_dirty_vars + ) + if substate._get_was_touched() or substate._previous_dirty_vars: + affected_tokens.update(substate._linked_from) + class SharedState(SharedStateBaseInternal, mixin=True): """Mixin for defining new shared states.""" diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index 4f3d5b14453..383cd8b0472 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -5,9 +5,12 @@ import uuid from collections.abc import Callable, Generator +import httpx import pytest +from reflex_base.config import get_config from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys +from selenium.webdriver.remote.webelement import WebElement from reflex.testing import AppHarness, WebDriver @@ -63,6 +66,15 @@ async def handle_submit(self, form_data: dict[str, Any]): if "token" in form_data: await self.link_to(form_data["token"]) + class SharedNotes(rx.SharedState): + """A second SharedState to test multi-SharedState propagation.""" + + note: str = "" + + @rx.event + async def on_load_link_default(self): + await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] + class PrivateState(rx.State): @rx.var async def greeting(self) -> str: @@ -140,10 +152,32 @@ def index() -> rx.Component: on_click=SharedState.link_to_and_increment, id="link-increment-button", ), + rx.text(SharedNotes.note, id="shared-note"), ) - app = rx.App() - app.add_page(index, route="/room/[room]", on_load=SharedState.on_load_link_default) + from fastapi import FastAPI + + api = FastAPI() + + @api.get("/api/set-counter/{shared_token}/{value}") + async def set_counter_api(shared_token: str, value: int): + """Modify shared state by its shared token from an API route.""" + from reflex.istate.manager.token import BaseStateToken + + async with app.modify_state( + BaseStateToken(ident=shared_token, cls=SharedState), + ) as state: + ss = await state.get_state(SharedState) + ss.counter = value + notes = await state.get_state(SharedNotes) + notes.note = f"counter set to {value}" + + app = rx.App(api_transformer=api) + app.add_page( + index, + route="/room/[room]", + on_load=[SharedState.on_load_link_default, SharedNotes.on_load_link_default], + ) app.add_page(index) @@ -386,3 +420,74 @@ def test_linked_state( # Link to a new state and increment the counter in the same event tab1.find_element(By.ID, "link-increment-button").click() assert linked_state.poll_for_content(counter_button_1, exp_not_equal="3") == "1" + + +def _open_linked_tab( + harness: AppHarness, + tab_factory: Callable[[], WebDriver], + shared_token: str, +) -> tuple[WebElement, WebElement]: + """Open a new tab linked to a shared token and return key elements. + + Args: + harness: The running AppHarness. + tab_factory: Factory to create WebDriver instances. + shared_token: The shared token to link to via on_load. + + Returns: + Tuple of (counter_button, note_element). + """ + tab = tab_factory() + tab.get(f"{harness.frontend_url}room/{shared_token}") + ss = utils.SessionStorage(tab) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + counter_button = AppHarness._poll_for( + lambda: tab.find_element(By.ID, "counter-button") + ) + assert counter_button + assert harness.poll_for_content(counter_button) == "0" + note = tab.find_element(By.ID, "shared-note") + assert note.text == "" + return counter_button, note + + +def test_modify_shared_state_by_shared_token( + linked_state: AppHarness, + tab_factory: Callable[[], WebDriver], +): + """Test that modifying shared state by shared token propagates to all linked clients. + + This exercises the use case of modifying shared state from an API route + where only the shared token is known (no private client token). + + Args: + linked_state: harness for LinkedStateApp. + tab_factory: factory to create WebDriver instances. + """ + assert linked_state.app_instance is not None + + shared_token = f"api-test-{uuid.uuid4()}" + + # Open two tabs linked to the same shared token via on_load + counter_button_1, note_1 = _open_linked_tab(linked_state, tab_factory, shared_token) + counter_button_2, note_2 = _open_linked_tab(linked_state, tab_factory, shared_token) + + # Modify both shared states by shared token via API route + api_url = f"{get_config().api_url}/api/set-counter/{shared_token}/42" + response = httpx.get(api_url) + assert response.status_code == 200 + + # Both tabs should see updates to both SharedState and SharedNotes + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42" + assert ( + linked_state.poll_for_content(note_1, exp_not_equal="") == "counter set to 42" + ) + assert ( + linked_state.poll_for_content(note_2, exp_not_equal="") == "counter set to 42" + ) + + # After the API-driven update, normal event handlers should still work + counter_button_1.click() + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="42") == "43" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="42") == "43" diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 2a3eb5556f4..c3e36a1df21 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import contextvars import functools import io import json @@ -2611,3 +2612,112 @@ class Sub(Base): ) else: assert app._event_namespace.emit_update.call_count == 0 + + +@pytest.fixture +def app_with_processor() -> App: + """Create an App with a mocked event processor that has a root context. + + Returns: + An App instance with a mock event processor and root context. + """ + app = App(_state=EmptyState) + root_context = EventContext( + token="", + state_manager=StateManagerMemory(), + enqueue_impl=AsyncMock(), + ) + processor = Mock() + processor._root_context = root_context + app._event_processor = processor + return app + + +def _run_isolated(fn): + """Run fn in a fresh empty context so all contextvars start unset. + + Args: + fn: A zero-argument callable to run. + """ + contextvars.Context().run(fn) + + +@pytest.mark.parametrize( + ("pre_set_registration", "pre_set_event"), + [ + pytest.param(False, False, id="neither_set"), + pytest.param(True, False, id="registration_already_set"), + pytest.param(False, True, id="event_already_set"), + pytest.param(True, True, id="both_already_set"), + ], +) +def test_set_contexts( + app_with_processor: App, + pre_set_registration: bool, + pre_set_event: bool, +): + """set_contexts sets absent contexts, preserves existing ones, and resets on exit.""" + + def _test(): + existing_reg = None + existing_ev = None + + if pre_set_registration: + existing_reg = RegistrationContext() + RegistrationContext.set(existing_reg) + if pre_set_event: + existing_ev = EventContext( + token="pre-existing", + state_manager=StateManagerMemory(), + enqueue_impl=AsyncMock(), + ) + EventContext.set(existing_ev) + + with app_with_processor.set_contexts(): + # Pre-existing contexts are preserved; absent ones are filled in. + if existing_reg is not None: + assert RegistrationContext.get() is existing_reg + else: + assert ( + RegistrationContext.get() + is app_with_processor._registration_context + ) + + if existing_ev is not None: + assert EventContext.get() is existing_ev + else: + assert app_with_processor._event_processor is not None + assert ( + EventContext.get() + is app_with_processor._event_processor._root_context + ) + + # After exit: pushed contexts are reset, pre-existing ones remain. + if existing_reg is not None: + assert RegistrationContext.get() is existing_reg + else: + with pytest.raises(LookupError): + RegistrationContext.get() + + if existing_ev is not None: + assert EventContext.get() is existing_ev + else: + with pytest.raises(LookupError): + EventContext.get() + + _run_isolated(_test) + + +def test_set_contexts_no_event_processor(): + """When event processor is None, EventContext should not be touched.""" + + def _test(): + app = App(_state=EmptyState) + assert app._event_processor is None + + with app.set_contexts(): + assert RegistrationContext.get() is app._registration_context + with pytest.raises(LookupError): + EventContext.get() + + _run_isolated(_test) From 7cfd39b229761e2aae11d1abb7d2cd971e3ff4ba Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 Apr 2026 12:08:56 -1000 Subject: [PATCH 2/4] deal with race condition due to slower execution in CI environment set an initial value for the shared state to provide affirmative proof that the state was linked _before_ the API call was made. if only one state was linked when the HTTP request was made, then only one state would be updated by such request resulting in the observed behavior. --- tests/integration/test_linked_state.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index 383cd8b0472..a8c20271d64 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -74,6 +74,8 @@ class SharedNotes(rx.SharedState): @rx.event async def on_load_link_default(self): await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] + if not self.note: + self.note = "linked" class PrivateState(rx.State): @rx.var @@ -447,7 +449,10 @@ def _open_linked_tab( assert counter_button assert harness.poll_for_content(counter_button) == "0" note = tab.find_element(By.ID, "shared-note") - assert note.text == "" + # Wait for SharedNotes.on_load_link_default to complete (sets note="linked"). + # This ensures both on_load handlers have finished before returning, since + # SharedNotes' handler runs after SharedState's and events are sequential. + assert harness.poll_for_content(note) == "linked" return counter_button, note @@ -481,10 +486,12 @@ def test_modify_shared_state_by_shared_token( assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42" assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42" assert ( - linked_state.poll_for_content(note_1, exp_not_equal="") == "counter set to 42" + linked_state.poll_for_content(note_1, exp_not_equal="linked") + == "counter set to 42" ) assert ( - linked_state.poll_for_content(note_2, exp_not_equal="") == "counter set to 42" + linked_state.poll_for_content(note_2, exp_not_equal="linked") + == "counter set to 42" ) # After the API-driven update, normal event handlers should still work From e5f513510dc2efaa1dd47926a0652271a9671098 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 Apr 2026 12:34:23 -1000 Subject: [PATCH 3/4] followup from previous commit but actually set the value on the newly linked state, not the old private state. derpy easy to make mistake with this _link_to mechanism =/ --- tests/integration/test_linked_state.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index a8c20271d64..535be5f819c 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -58,6 +58,8 @@ async def on_load_link_default(self): assert linked_state._linked_to == self.room # pyright: ignore[reportAttributeAccessIssue] else: assert linked_state._linked_to == "default" + if linked_state.counter == 0: + linked_state.counter = -1 @rx.event async def handle_submit(self, form_data: dict[str, Any]): @@ -73,9 +75,9 @@ class SharedNotes(rx.SharedState): @rx.event async def on_load_link_default(self): - await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] - if not self.note: - self.note = "linked" + linked_state = await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] + if not linked_state.note: + linked_state.note = "linked" class PrivateState(rx.State): @rx.var @@ -447,11 +449,10 @@ def _open_linked_tab( lambda: tab.find_element(By.ID, "counter-button") ) assert counter_button - assert harness.poll_for_content(counter_button) == "0" + # Wait for SharedState.on_load_link_default (sets counter=-1). + assert harness.poll_for_content(counter_button, exp_not_equal="0") == "-1" note = tab.find_element(By.ID, "shared-note") - # Wait for SharedNotes.on_load_link_default to complete (sets note="linked"). - # This ensures both on_load handlers have finished before returning, since - # SharedNotes' handler runs after SharedState's and events are sequential. + # Wait for SharedNotes.on_load_link_default (sets note="linked"). assert harness.poll_for_content(note) == "linked" return counter_button, note @@ -483,8 +484,8 @@ def test_modify_shared_state_by_shared_token( assert response.status_code == 200 # Both tabs should see updates to both SharedState and SharedNotes - assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42" - assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42" + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="-1") == "42" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="-1") == "42" assert ( linked_state.poll_for_content(note_1, exp_not_equal="linked") == "counter set to 42" From 4e5aef83476f90b3b387b6463b5a13fbb3358d87 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 17 Apr 2026 13:33:29 -1000 Subject: [PATCH 4/4] handle fetching linked states via `BaseState.get_state` if the linked state is not fetched in the initial request, then make sure we check if it should be patched in when performing a `get_state` call --- reflex/istate/shared.py | 51 +++++++++++++++ reflex/state.py | 14 ++-- tests/integration/test_linked_state.py | 88 +++++++++++++++++++++----- 3 files changed, 129 insertions(+), 24 deletions(-) diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 58e1b46ff02..31f28ec763a 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -173,6 +173,45 @@ def _rehydrate(self): State.set_is_hydrated(True), ] + async def _resolve_linked_state( + self, state_cls: type["BaseState"], linked_token: str + ) -> "BaseState": + """Load and patch a linked state that was not pre-loaded in the tree. + + Called by State._get_state_from_redis when a state in + _reflex_internal_links is not yet in the cache. This loads the + private copy into the tree first, then patches the linked version + on top of it via _internal_patch_linked_state. + + Args: + state_cls: The shared state class to resolve. + linked_token: The shared token the state is linked to. + + Returns: + The linked state instance, patched into the current tree. + """ + root_state = self._get_root_state() + + # Load the private copy into the tree so _internal_patch_linked_state + # has an original to swap out (needed for unlink / restore). + await BaseState._get_state_from_redis(root_state, state_cls) + + # Retrieve the private instance that was just attached to the tree. + original_state = root_state._get_state_from_cache(state_cls) + + # If we are inside _modify_linked_states, we can properly patch the + # linked state using the exit stack for lock management. + if ( + self._exit_stack is not None + and self._held_locks is not None + and isinstance(original_state, SharedStateBaseInternal) + ): + return await original_state._internal_patch_linked_state(linked_token) + + # Outside _modify_linked_states context - return the private copy as + # a safe fallback (this path should rarely be hit in practice). + return original_state + async def _link_to(self, token: str) -> Self: """Link this shared state to a token. @@ -280,6 +319,18 @@ async def _internal_patch_linked_state( BaseStateToken(ident=token, cls=type(self)) ) ) + # Set client_token on the linked root so that subsequent get_state + # calls (which fall through to _get_state_from_redis) use the + # linked token rather than an empty/default value. + if linked_root_state.router.session.client_token != token: + import dataclasses as dc + + linked_root_state.router = dc.replace( + linked_root_state.router, + session=dc.replace( + linked_root_state.router.session, client_token=token + ), + ) self._held_locks.setdefault(token, {}) else: linked_root_state = await get_state_manager().get_state( diff --git a/reflex/state.py b/reflex/state.py index def7fcc382f..bbd78605c58 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2146,7 +2146,6 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: Returns: The instance of state_cls associated with this state's client_token. """ - state_instance = await super()._get_state_from_redis(state_cls) if ( self._reflex_internal_links and ( @@ -2155,15 +2154,12 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: ) ) is not None - and ( - internal_patch_linked_state := getattr( - state_instance, "_internal_patch_linked_state", None - ) - ) - is not None ): - return await internal_patch_linked_state(linked_token) - return state_instance + from reflex.istate.shared import SharedStateBaseInternal + + shared_base = await self.get_state(SharedStateBaseInternal) + return await shared_base._resolve_linked_state(state_cls, linked_token) # type: ignore[return-value] + return await super()._get_state_from_redis(state_cls) @event async def hydrate(self) -> None: diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index 535be5f819c..e3d234dc6e0 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -58,8 +58,6 @@ async def on_load_link_default(self): assert linked_state._linked_to == self.room # pyright: ignore[reportAttributeAccessIssue] else: assert linked_state._linked_to == "default" - if linked_state.counter == 0: - linked_state.counter = -1 @rx.event async def handle_submit(self, form_data: dict[str, Any]): @@ -75,11 +73,11 @@ class SharedNotes(rx.SharedState): @rx.event async def on_load_link_default(self): - linked_state = await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] - if not linked_state.note: - linked_state.note = "linked" + await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] class PrivateState(rx.State): + fetched_note: str = "" + @rx.var async def greeting(self) -> str: ss = await self.get_state(SharedState) @@ -90,6 +88,12 @@ async def linked_to(self) -> str: ss = await self.get_state(SharedState) return ss._linked_to + @rx.event + async def fetch_shared_note(self): + """Fetch SharedNotes via get_state from an unrelated state handler.""" + sn = await self.get_state(SharedNotes) + self.fetched_note = sn.note + @rx.event(background=True) async def bump_counter_bg(self): for _ in range(5): @@ -157,6 +161,12 @@ def index() -> rx.Component: id="link-increment-button", ), rx.text(SharedNotes.note, id="shared-note"), + rx.button( + "Fetch Note via get_state", + on_click=PrivateState.fetch_shared_note, + id="fetch-note-button", + ), + rx.text(PrivateState.fetched_note, id="fetched-note"), ) from fastapi import FastAPI @@ -449,11 +459,12 @@ def _open_linked_tab( lambda: tab.find_element(By.ID, "counter-button") ) assert counter_button - # Wait for SharedState.on_load_link_default (sets counter=-1). - assert harness.poll_for_content(counter_button, exp_not_equal="0") == "-1" + assert harness.poll_for_content(counter_button) == "0" + # Wait for SharedState.on_load_link_default to complete (linked-to shows the token). + linked_to = tab.find_element(By.ID, "linked-to") + assert harness.poll_for_content(linked_to) == shared_token note = tab.find_element(By.ID, "shared-note") - # Wait for SharedNotes.on_load_link_default (sets note="linked"). - assert harness.poll_for_content(note) == "linked" + assert note.text == "" return counter_button, note @@ -484,18 +495,65 @@ def test_modify_shared_state_by_shared_token( assert response.status_code == 200 # Both tabs should see updates to both SharedState and SharedNotes - assert linked_state.poll_for_content(counter_button_1, exp_not_equal="-1") == "42" - assert linked_state.poll_for_content(counter_button_2, exp_not_equal="-1") == "42" + assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42" + assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42" assert ( - linked_state.poll_for_content(note_1, exp_not_equal="linked") - == "counter set to 42" + linked_state.poll_for_content(note_1, exp_not_equal="") == "counter set to 42" ) assert ( - linked_state.poll_for_content(note_2, exp_not_equal="linked") - == "counter set to 42" + linked_state.poll_for_content(note_2, exp_not_equal="") == "counter set to 42" ) # After the API-driven update, normal event handlers should still work counter_button_1.click() assert linked_state.poll_for_content(counter_button_1, exp_not_equal="42") == "43" assert linked_state.poll_for_content(counter_button_2, exp_not_equal="42") == "43" + + +def test_get_state_returns_linked_state( + linked_state: AppHarness, + tab_factory: Callable[[], WebDriver], +): + """Test that get_state from an unrelated handler returns the linked instance. + + When SharedNotes is linked to a shared token, calling + ``await self.get_state(SharedNotes)`` from PrivateState (an unrelated + state handler) should return the linked SharedNotes — not the private + copy. With the Redis state manager, SharedNotes may not be pre-loaded + in the tree, so ``get_state`` falls through to the redis fetch path. + + Args: + linked_state: harness for LinkedStateApp. + tab_factory: factory to create WebDriver instances. + """ + assert linked_state.app_instance is not None + + shared_token = f"get-state-test-{uuid.uuid4()}" + + # Open a tab linked to the shared token via on_load + tab = tab_factory() + tab.get(f"{linked_state.frontend_url}room/{shared_token}") + ss = utils.SessionStorage(tab) + assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" + linked_to = AppHarness._poll_for(lambda: tab.find_element(By.ID, "linked-to")) + assert linked_to + # Wait for on_load to link SharedState (confirms event processing started). + assert linked_state.poll_for_content(linked_to) == shared_token + + # Modify SharedNotes.note on the shared token directly via the API. + api_url = f"{get_config().api_url}/api/set-counter/{shared_token}/99" + response = httpx.get(api_url) + assert response.status_code == 200 + + # Verify the linked note appears on the page (direct SharedNotes binding). + note = tab.find_element(By.ID, "shared-note") + assert linked_state.poll_for_content(note, exp_not_equal="") == "counter set to 99" + + # Now trigger PrivateState.fetch_shared_note — this calls get_state(SharedNotes) + # from an unrelated state handler. The returned instance must be the + # *linked* SharedNotes (with note="counter set to 99"), not the private copy. + tab.find_element(By.ID, "fetch-note-button").click() + fetched = tab.find_element(By.ID, "fetched-note") + assert ( + linked_state.poll_for_content(fetched, exp_not_equal="") == "counter set to 99" + )