Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import traceback
import urllib.parse
from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence
from contextvars import Token
from datetime import datetime
from itertools import chain
from pathlib import Path
Expand All @@ -31,6 +32,7 @@
evaluate_style_namespaces,
)
from reflex_base.config import get_config
from reflex_base.context.base import BaseContext
from reflex_base.environment import ExecutorType, environment
from reflex_base.event import (
_EVENT_FIELDS,
Expand All @@ -40,6 +42,7 @@
IndividualEventType,
noop,
)
from reflex_base.event.context import EventContext
from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor
from reflex_base.registry import RegistrationContext
from reflex_base.utils import console
Expand Down Expand Up @@ -574,8 +577,65 @@ async def modified_send(message: Message):
# Ensure the event processor starts and stops with the server.
self.register_lifespan_task(self._setup_event_processor)

def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
"""Ensure the RegistrationContext is attached to the ASGI app.
def _set_contexts_internal(self) -> dict[type[BaseContext], Token]:
"""Set Reflex contexts if not already present, returning reset tokens.

Returns:
A dict mapping context class to the contextvars Token for each
context that was set. Empty if all contexts were already present.
"""
tokens: dict[type[BaseContext], Token] = {}

if self._registration_context is not None:
try:
RegistrationContext.get()
except LookupError:
tokens[RegistrationContext] = RegistrationContext.set(
self._registration_context
)

if (
self._event_processor is not None
and self._event_processor._root_context is not None
):
try:
EventContext.get()
except LookupError:
tokens[EventContext] = EventContext.set(
self._event_processor._root_context
)

return tokens

def set_contexts(self) -> contextlib.AbstractContextManager:
"""Set Reflex contexts needed for state and event processing.

Pushes RegistrationContext and EventContext into the current
contextvars scope, but only if they are not already set.

Can be used as a context manager::

with app.set_contexts():
async with app.modify_state(token) as state:
...

Returns:
A context manager that resets any contexts that were set on exit.
"""
tokens = self._set_contexts_internal()
if not tokens:
return contextlib.nullcontext()
stack = contextlib.ExitStack()
for ctx_cls, tok in tokens.items():
stack.callback(ctx_cls.reset, tok)
return stack

def _context_middleware(self, app: ASGIApp) -> ASGIApp:
"""Ensure Reflex contexts are attached for each ASGI request.

Many ASGI servers start each request with a fresh contextvars scope,
so this middleware re-applies the RegistrationContext and EventContext
that are needed for Reflex state and event processing.

Args:
app: The ASGI app to attach the middleware to.
Expand All @@ -584,14 +644,11 @@ def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp:
The ASGI app with the middleware attached.
"""

async def registration_context_middleware(
scope: Scope, receive: Receive, send: Send
):
if self._registration_context is not None:
RegistrationContext.set(self._registration_context)
async def context_middleware(scope: Scope, receive: Receive, send: Send):
self._set_contexts_internal()
await app(scope, receive, send)

return registration_context_middleware
return context_middleware

@contextlib.asynccontextmanager
async def _setup_event_processor(self) -> AsyncIterator[None]:
Expand Down Expand Up @@ -669,10 +726,10 @@ def __call__(self) -> ASGIApp:
asgi_app = api_transformer(asgi_app)

top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks)
# Make sure the RegistrationContext is attached.
# Make sure Reflex contexts are attached for each request.
top_asgi_app.mount(
"",
self._registration_context_middleware(asgi_app),
self._context_middleware(asgi_app),
)
App._add_cors(top_asgi_app)
return top_asgi_app
Expand Down Expand Up @@ -1607,20 +1664,22 @@ async def modify_state(
if isinstance(token, str):
token = BaseStateToken.from_legacy_token(token, root_state=self._state)

# Get exclusive access to the state.
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars, **context
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
state._clean()
if delta:
# When the frontend vars are modified emit the delta to the frontend.
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
token=token.ident,
)
# Ensure Reflex contexts are available (e.g. when called from an API route).
with self.set_contexts():
# Get exclusive access to the state.
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars, **context
) as state:
# No other event handler can modify the state while in this context.
yield state
delta = await state._get_resolved_delta()
state._clean()
if delta:
# When the frontend vars are modified emit the delta to the frontend.
await self.event_namespace.emit_update(
update=StateUpdate(delta=delta),
token=token.ident,
)

def _validate_exception_handlers(self):
"""Validate the custom event exception handlers for front- and backend.
Expand Down
33 changes: 33 additions & 0 deletions reflex/istate/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,13 @@ async def _modify_linked_states(
for token in linked_state._linked_from
if token != self.router.session.client_token
)
# When modifying a shared token directly (empty _reflex_internal_links),
# the held locks will be empty. Check SharedState substates for linked
# clients that need to be notified.
if not self._reflex_internal_links:
self._collect_shared_token_updates(
affected_tokens, current_dirty_vars
)
finally:
self._exit_stack = None

Expand All @@ -397,6 +404,32 @@ async def _modify_linked_states(
state_type=type(self),
)

def _collect_shared_token_updates(
self,
affected_tokens: set[str],
current_dirty_vars: dict[str, set[str]],
) -> None:
"""Collect dirty vars and linked clients from SharedState substates.

When a shared state is modified directly by its shared token (rather than
through a private client token), the held locks are empty so the normal
collection loop above finds nothing. This method checks the SharedState
substates directly for linked clients that need to be notified.

Args:
affected_tokens: Set to update with client tokens that need notification.
current_dirty_vars: Dict to update with dirty var mappings per state.
"""
for substate in self.substates.values():
if not isinstance(substate, SharedState) or not substate._linked_from:
continue
if substate._previous_dirty_vars:
current_dirty_vars[substate.get_full_name()] = set(
substate._previous_dirty_vars
)
if substate._get_was_touched() or substate._previous_dirty_vars:
affected_tokens.update(substate._linked_from)
Comment on lines +426 to +431
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Truthiness check diverges from the existing is not None pattern

The existing collection loop (lines 376–383) uses is not None to decide whether to populate current_dirty_vars and affected_tokens, treating an empty set() the same as a populated one. This new method uses a truthiness check, so an empty _previous_dirty_vars = set() (the field's initialiser value) skips the current_dirty_vars entry even when _get_was_touched() is True.

In practice _previous_dirty_vars will be non-empty whenever vars were actually modified, so this is unlikely to surface as a bug; but aligning with the is not None guard used in the sibling loop keeps the two paths consistent and easier to reason about.

Suggested change
if substate._previous_dirty_vars:
current_dirty_vars[substate.get_full_name()] = set(
substate._previous_dirty_vars
)
if substate._get_was_touched() or substate._previous_dirty_vars:
affected_tokens.update(substate._linked_from)
if substate._previous_dirty_vars is not None:
current_dirty_vars[substate.get_full_name()] = set(
substate._previous_dirty_vars
)
if substate._get_was_touched() or substate._previous_dirty_vars is not None:
affected_tokens.update(substate._linked_from)



class SharedState(SharedStateBaseInternal, mixin=True):
"""Mixin for defining new shared states."""
Expand Down
116 changes: 114 additions & 2 deletions tests/integration/test_linked_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import uuid
from collections.abc import Callable, Generator

import httpx
import pytest
from reflex_base.config import get_config
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.remote.webelement import WebElement

from reflex.testing import AppHarness, WebDriver

Expand Down Expand Up @@ -63,6 +66,17 @@ async def handle_submit(self, form_data: dict[str, Any]):
if "token" in form_data:
await self.link_to(form_data["token"])

class SharedNotes(rx.SharedState):
"""A second SharedState to test multi-SharedState propagation."""

note: str = ""

@rx.event
async def on_load_link_default(self):
await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue]
if not self.note:
self.note = "linked"

class PrivateState(rx.State):
@rx.var
async def greeting(self) -> str:
Expand Down Expand Up @@ -140,10 +154,32 @@ def index() -> rx.Component:
on_click=SharedState.link_to_and_increment,
id="link-increment-button",
),
rx.text(SharedNotes.note, id="shared-note"),
)

app = rx.App()
app.add_page(index, route="/room/[room]", on_load=SharedState.on_load_link_default)
from fastapi import FastAPI

api = FastAPI()

@api.get("/api/set-counter/{shared_token}/{value}")
async def set_counter_api(shared_token: str, value: int):
"""Modify shared state by its shared token from an API route."""
from reflex.istate.manager.token import BaseStateToken

async with app.modify_state(
BaseStateToken(ident=shared_token, cls=SharedState),
) as state:
ss = await state.get_state(SharedState)
ss.counter = value
notes = await state.get_state(SharedNotes)
notes.note = f"counter set to {value}"

app = rx.App(api_transformer=api)
app.add_page(
index,
route="/room/[room]",
on_load=[SharedState.on_load_link_default, SharedNotes.on_load_link_default],
)
app.add_page(index)


Expand Down Expand Up @@ -386,3 +422,79 @@ def test_linked_state(
# Link to a new state and increment the counter in the same event
tab1.find_element(By.ID, "link-increment-button").click()
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="3") == "1"


def _open_linked_tab(
harness: AppHarness,
tab_factory: Callable[[], WebDriver],
shared_token: str,
) -> tuple[WebElement, WebElement]:
"""Open a new tab linked to a shared token and return key elements.

Args:
harness: The running AppHarness.
tab_factory: Factory to create WebDriver instances.
shared_token: The shared token to link to via on_load.

Returns:
Tuple of (counter_button, note_element).
"""
tab = tab_factory()
tab.get(f"{harness.frontend_url}room/{shared_token}")
ss = utils.SessionStorage(tab)
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"
counter_button = AppHarness._poll_for(
lambda: tab.find_element(By.ID, "counter-button")
)
assert counter_button
assert harness.poll_for_content(counter_button) == "0"
note = tab.find_element(By.ID, "shared-note")
# Wait for SharedNotes.on_load_link_default to complete (sets note="linked").
# This ensures both on_load handlers have finished before returning, since
# SharedNotes' handler runs after SharedState's and events are sequential.
assert harness.poll_for_content(note) == "linked"
return counter_button, note


def test_modify_shared_state_by_shared_token(
linked_state: AppHarness,
tab_factory: Callable[[], WebDriver],
):
"""Test that modifying shared state by shared token propagates to all linked clients.

This exercises the use case of modifying shared state from an API route
where only the shared token is known (no private client token).

Args:
linked_state: harness for LinkedStateApp.
tab_factory: factory to create WebDriver instances.
"""
assert linked_state.app_instance is not None

shared_token = f"api-test-{uuid.uuid4()}"

# Open two tabs linked to the same shared token via on_load
counter_button_1, note_1 = _open_linked_tab(linked_state, tab_factory, shared_token)
counter_button_2, note_2 = _open_linked_tab(linked_state, tab_factory, shared_token)

# Modify both shared states by shared token via API route
api_url = f"{get_config().api_url}/api/set-counter/{shared_token}/42"
response = httpx.get(api_url)
assert response.status_code == 200

# Both tabs should see updates to both SharedState and SharedNotes
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="0") == "42"
assert linked_state.poll_for_content(counter_button_2, exp_not_equal="0") == "42"
assert (
linked_state.poll_for_content(note_1, exp_not_equal="linked")
== "counter set to 42"
)
assert (
linked_state.poll_for_content(note_2, exp_not_equal="linked")
== "counter set to 42"
)

# After the API-driven update, normal event handlers should still work
counter_button_1.click()
assert linked_state.poll_for_content(counter_button_1, exp_not_equal="42") == "43"
assert linked_state.poll_for_content(counter_button_2, exp_not_equal="42") == "43"
Loading
Loading