Skip to content

Commit b977e62

Browse files
authored
ENG-9350: App.modify_state can directly modify SharedState tokens (#6336)
* ENG-9350: App.modify_state can directly modify SharedState tokens When a SharedState is modified directly by its shared token (e.g. from an API route webhook), propagate dirty vars to all clients linked to that shared state — matching the behavior that already exists when modifying shared state through a private client token. - Add _collect_shared_token_updates on SharedStateBaseInternal to detect the shared-token case inside _modify_linked_states and propagate to linked clients via the existing _do_update_other_tokens mechanism - Add App.set_contexts / App._set_contexts_internal to centralize pushing RegistrationContext and EventContext into the current contextvars scope, replacing the old _registration_context_middleware - App.modify_state now calls set_contexts so out-of-band callers (API routes, webhooks) get the necessary contexts automatically - Integration test: API endpoint modifies two SharedState subclasses by shared token, asserts both propagate to two linked browser tabs, and verifies normal event handlers still work afterward - Unit tests for set_contexts covering all combinations of pre-existing / absent contexts, no-event-processor, and reset-on-exit behavior Closes #6335 * deal with race condition due to slower execution in CI environment set an initial value for the shared state to provide affirmative proof that the state was linked _before_ the API call was made. if only one state was linked when the HTTP request was made, then only one state would be updated by such request resulting in the observed behavior. * followup from previous commit but actually set the value on the newly linked state, not the old private state. derpy easy to make mistake with this _link_to mechanism =/ * Add test case for fetching previously unfetched linked state Using BaseState.get_state(...) with a linked state that _is_ linked, but not currently cached should fetch the linked state instance, not the private state instance. * Handle `.get_state` when directly modifying a shared token In an `app.modify_state(...)` context where the passed token is a shared token, handle the case where `.get_state(...)` is used to fetch _another_ shared state with the same token as the original shared state token. Basically if some code is modifying "tokenA" for "Shared1" and calls `shared_1.get_state(Shared2)`, then the retrieved state will also be associated with "tokenA", as expected. * recursively _collect_shared_token_updates Starting from the SharedStateBaseInternal, recurse into each pre-cached instance and collect affected tokens and dirty vars. * Fixturify the new test_app set_contexts test case * remove superfluous `run_isolated` fixture just use `isolated_context.run` for simplicity
1 parent 3bc86bb commit b977e62

5 files changed

Lines changed: 491 additions & 36 deletions

File tree

reflex/app.py

Lines changed: 83 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import traceback
1717
import urllib.parse
1818
from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence
19+
from contextvars import Token
1920
from datetime import datetime
2021
from itertools import chain
2122
from pathlib import Path
@@ -31,6 +32,7 @@
3132
evaluate_style_namespaces,
3233
)
3334
from reflex_base.config import get_config
35+
from reflex_base.context.base import BaseContext
3436
from reflex_base.environment import ExecutorType, environment
3537
from reflex_base.event import (
3638
_EVENT_FIELDS,
@@ -40,6 +42,7 @@
4042
IndividualEventType,
4143
noop,
4244
)
45+
from reflex_base.event.context import EventContext
4346
from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor
4447
from reflex_base.registry import RegistrationContext
4548
from reflex_base.utils import console
@@ -577,8 +580,65 @@ async def modified_send(message: Message):
577580
# Ensure the event processor starts and stops with the server.
578581
self.register_lifespan_task(self._setup_event_processor)
579582

580-
def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
581-
"""Ensure the RegistrationContext is attached to the ASGI app.
583+
def _set_contexts_internal(self) -> dict[type[BaseContext], Token]:
584+
"""Set Reflex contexts if not already present, returning reset tokens.
585+
586+
Returns:
587+
A dict mapping context class to the contextvars Token for each
588+
context that was set. Empty if all contexts were already present.
589+
"""
590+
tokens: dict[type[BaseContext], Token] = {}
591+
592+
if self._registration_context is not None:
593+
try:
594+
RegistrationContext.get()
595+
except LookupError:
596+
tokens[RegistrationContext] = RegistrationContext.set(
597+
self._registration_context
598+
)
599+
600+
if (
601+
self._event_processor is not None
602+
and self._event_processor._root_context is not None
603+
):
604+
try:
605+
EventContext.get()
606+
except LookupError:
607+
tokens[EventContext] = EventContext.set(
608+
self._event_processor._root_context
609+
)
610+
611+
return tokens
612+
613+
def set_contexts(self) -> contextlib.AbstractContextManager:
614+
"""Set Reflex contexts needed for state and event processing.
615+
616+
Pushes RegistrationContext and EventContext into the current
617+
contextvars scope, but only if they are not already set.
618+
619+
Can be used as a context manager::
620+
621+
with app.set_contexts():
622+
async with app.modify_state(token) as state:
623+
...
624+
625+
Returns:
626+
A context manager that resets any contexts that were set on exit.
627+
"""
628+
tokens = self._set_contexts_internal()
629+
if not tokens:
630+
return contextlib.nullcontext()
631+
stack = contextlib.ExitStack()
632+
for ctx_cls, tok in tokens.items():
633+
stack.callback(ctx_cls.reset, tok)
634+
return stack
635+
636+
def _context_middleware(self, app: ASGIApp) -> ASGIApp:
637+
"""Ensure Reflex contexts are attached for each ASGI request.
638+
639+
Many ASGI servers start each request with a fresh contextvars scope,
640+
so this middleware re-applies the RegistrationContext and EventContext
641+
that are needed for Reflex state and event processing.
582642
583643
Args:
584644
app: The ASGI app to attach the middleware to.
@@ -587,14 +647,11 @@ def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
587647
The ASGI app with the middleware attached.
588648
"""
589649

590-
async def registration_context_middleware(
591-
scope: Scope, receive: Receive, send: Send
592-
):
593-
if self._registration_context is not None:
594-
RegistrationContext.set(self._registration_context)
650+
async def context_middleware(scope: Scope, receive: Receive, send: Send):
651+
self._set_contexts_internal()
595652
await app(scope, receive, send)
596653

597-
return registration_context_middleware
654+
return context_middleware
598655

599656
@contextlib.asynccontextmanager
600657
async def _setup_event_processor(self) -> AsyncIterator[None]:
@@ -672,10 +729,10 @@ def __call__(self) -> ASGIApp:
672729
asgi_app = api_transformer(asgi_app)
673730

674731
top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
675-
# Make sure the RegistrationContext is attached.
732+
# Make sure Reflex contexts are attached for each request.
676733
top_asgi_app.mount(
677734
"",
678-
self._registration_context_middleware(asgi_app),
735+
self._context_middleware(asgi_app),
679736
)
680737
App._add_cors(top_asgi_app)
681738
return top_asgi_app
@@ -1615,20 +1672,22 @@ async def modify_state(
16151672
if isinstance(token, str):
16161673
token = BaseStateToken.from_legacy_token(token, root_state=self._state)
16171674

1618-
# Get exclusive access to the state.
1619-
async with self.state_manager.modify_state_with_links(
1620-
token, previous_dirty_vars=previous_dirty_vars, **context
1621-
) as state:
1622-
# No other event handler can modify the state while in this context.
1623-
yield state
1624-
delta = await state._get_resolved_delta()
1625-
state._clean()
1626-
if delta:
1627-
# When the frontend vars are modified emit the delta to the frontend.
1628-
await self.event_namespace.emit_update(
1629-
update=StateUpdate(delta=delta),
1630-
token=token.ident,
1631-
)
1675+
# Ensure Reflex contexts are available (e.g. when called from an API route).
1676+
with self.set_contexts():
1677+
# Get exclusive access to the state.
1678+
async with self.state_manager.modify_state_with_links(
1679+
token, previous_dirty_vars=previous_dirty_vars, **context
1680+
) as state:
1681+
# No other event handler can modify the state while in this context.
1682+
yield state
1683+
delta = await state._get_resolved_delta()
1684+
state._clean()
1685+
if delta:
1686+
# When the frontend vars are modified emit the delta to the frontend.
1687+
await self.event_namespace.emit_update(
1688+
update=StateUpdate(delta=delta),
1689+
token=token.ident,
1690+
)
16321691

16331692
def _validate_exception_handlers(self):
16341693
"""Validate the custom event exception handlers for front- and backend.

reflex/istate/shared.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,38 @@ def _rehydrate(self):
173173
State.set_is_hydrated(True),
174174
]
175175

176+
async def _resolve_linked_state(
177+
self, state_cls: type["BaseState"], linked_token: str
178+
) -> "BaseState":
179+
"""Load and patch a linked state that was not pre-loaded in the tree.
180+
181+
Called by State._get_state_from_redis when a state in
182+
_reflex_internal_links is not yet in the cache. This loads the
183+
private copy into the tree first, then patches the linked version
184+
on top of it via _internal_patch_linked_state.
185+
186+
Args:
187+
state_cls: The shared state class to resolve.
188+
linked_token: The shared token the state is linked to.
189+
190+
Returns:
191+
The linked state instance, patched into the current tree.
192+
193+
Raises:
194+
ReflexRuntimeError: If the resolved state is not a SharedState.
195+
"""
196+
root_state = self._get_root_state()
197+
198+
# Load the private copy into the tree so _internal_patch_linked_state
199+
# has an original to swap out (needed for unlink / restore).
200+
original_state = await BaseState._get_state_from_redis(root_state, state_cls)
201+
202+
if isinstance(original_state, SharedStateBaseInternal):
203+
return await original_state._internal_patch_linked_state(linked_token)
204+
205+
msg = f"Failed to resolve linked state {state_cls.get_full_name()} for token {linked_token}: state does not inherit from rx.SharedState"
206+
raise ReflexRuntimeError(msg)
207+
176208
async def _link_to(self, token: str) -> Self:
177209
"""Link this shared state to a token.
178210
@@ -194,7 +226,7 @@ async def _link_to(self, token: str) -> Self:
194226
raise ReflexRuntimeError(msg)
195227
if not isinstance(self, SharedState):
196228
msg = "Can only link SharedState instances."
197-
raise RuntimeError(msg)
229+
raise ReflexRuntimeError(msg)
198230
if self._linked_to == token:
199231
return self # already linked to this token
200232
if self._linked_to and self._linked_to != token:
@@ -280,6 +312,18 @@ async def _internal_patch_linked_state(
280312
BaseStateToken(ident=token, cls=type(self))
281313
)
282314
)
315+
# Set client_token on the linked root so that subsequent get_state
316+
# calls when directly modifying a linked token will load the
317+
# associated instance.
318+
if linked_root_state.router.session.client_token != token:
319+
import dataclasses as dc
320+
321+
linked_root_state.router = dc.replace(
322+
linked_root_state.router,
323+
session=dc.replace(
324+
linked_root_state.router.session, client_token=token
325+
),
326+
)
283327
self._held_locks.setdefault(token, {})
284328
else:
285329
linked_root_state = await get_state_manager().get_state(
@@ -386,6 +430,22 @@ async def _modify_linked_states(
386430
for token in linked_state._linked_from
387431
if token != self.router.session.client_token
388432
)
433+
# When modifying a shared token directly (empty _reflex_internal_links),
434+
# the held locks will be empty. Check SharedState substates for linked
435+
# clients that need to be notified.
436+
if not self._reflex_internal_links:
437+
shared_state_base_internal = await self.get_state(
438+
SharedStateBaseInternal
439+
)
440+
if not isinstance(
441+
shared_state_base_internal, SharedStateBaseInternal
442+
):
443+
msg = "Expected SharedStateBaseInternal in substates."
444+
raise ReflexRuntimeError(msg)
445+
# Collect affected tokens from all potentially linked states.
446+
shared_state_base_internal._collect_shared_token_updates(
447+
affected_tokens, current_dirty_vars
448+
)
389449
finally:
390450
self._exit_stack = None
391451

@@ -397,6 +457,34 @@ async def _modify_linked_states(
397457
state_type=type(self),
398458
)
399459

460+
def _collect_shared_token_updates(
461+
self,
462+
affected_tokens: set[str],
463+
current_dirty_vars: dict[str, set[str]],
464+
) -> None:
465+
"""Recursively collect dirty vars and linked clients from SharedState substates.
466+
467+
When a shared state is modified directly by its shared token (rather than
468+
through a private client token), the held locks are empty so the normal
469+
collection loop above finds nothing. This method recursively checks
470+
SharedState substates for linked clients that need to be notified.
471+
472+
Args:
473+
affected_tokens: Set to update with client tokens that need notification.
474+
current_dirty_vars: Dict to update with dirty var mappings per state.
475+
"""
476+
for substate in self.substates.values():
477+
if not isinstance(substate, SharedState):
478+
continue
479+
if substate._linked_from:
480+
if substate._previous_dirty_vars:
481+
current_dirty_vars[substate.get_full_name()] = set(
482+
substate._previous_dirty_vars
483+
)
484+
if substate._get_was_touched() or substate._previous_dirty_vars:
485+
affected_tokens.update(substate._linked_from)
486+
substate._collect_shared_token_updates(affected_tokens, current_dirty_vars)
487+
400488

401489
class SharedState(SharedStateBaseInternal, mixin=True):
402490
"""Mixin for defining new shared states."""

reflex/state.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,6 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
21462146
Returns:
21472147
The instance of state_cls associated with this state's client_token.
21482148
"""
2149-
state_instance = await super()._get_state_from_redis(state_cls)
21502149
if (
21512150
self._reflex_internal_links
21522151
and (
@@ -2155,15 +2154,12 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE:
21552154
)
21562155
)
21572156
is not None
2158-
and (
2159-
internal_patch_linked_state := getattr(
2160-
state_instance, "_internal_patch_linked_state", None
2161-
)
2162-
)
2163-
is not None
21642157
):
2165-
return await internal_patch_linked_state(linked_token)
2166-
return state_instance
2158+
from reflex.istate.shared import SharedStateBaseInternal
2159+
2160+
shared_base = await self.get_state(SharedStateBaseInternal)
2161+
return await shared_base._resolve_linked_state(state_cls, linked_token) # type: ignore[return-value]
2162+
return await super()._get_state_from_redis(state_cls)
21672163

21682164
@event
21692165
async def hydrate(self) -> None:

0 commit comments

Comments
 (0)