diff --git a/reflex/app.py b/reflex/app.py index 54682543a7d..d49ae2240fd 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -42,6 +42,7 @@ from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.staticfiles import StaticFiles +from typing_extensions import Unpack from reflex import constants from reflex.admin import AdminDash @@ -83,6 +84,7 @@ get_hydrate_event, noop, ) +from reflex.istate.manager import StateModificationContext from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES from reflex.route import ( @@ -1571,6 +1573,7 @@ async def modify_state( token: str, background: bool = False, previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], ) -> AsyncIterator[BaseState]: """Modify the state out of band. @@ -1591,7 +1594,7 @@ async def modify_state( # Get exclusive access to the state. async with self.state_manager.modify_state_with_links( - token, previous_dirty_vars=previous_dirty_vars + token, previous_dirty_vars=previous_dirty_vars, **context ) as state: # No other event handler can modify the state while in this context. yield state @@ -1624,7 +1627,7 @@ def _process_background( if not handler.is_background: return None - substate = StateProxy(substate) + substate = StateProxy(substate, event) async def _coro(): """Coroutine to process the event and emit updates inside an asyncio.Task. @@ -2042,7 +2045,7 @@ async def _ndjson_updates(): """ # Process the event. async with app.state_manager.modify_state_with_links( - event.substate_token + event.substate_token, event=event ) as state: async for update in state._process(event): # Postprocess the event. diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 1eae7550de3..ba56395e768 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -19,7 +19,7 @@ class StateModificationContext(TypedDict, total=False): """The context for modifying state.""" - event: ReadOnly[Event] + event: ReadOnly[Event | None] EmptyContext = StateModificationContext() diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 63fc90586aa..bbfd8e20ae2 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -363,12 +363,14 @@ async def set_state( # Check that we're holding the lock. if ( lock_id is not None - and await self.redis.get(self._lock_key(token)) != lock_id + and (existing_lock_id := await self.redis.get(self._lock_key(token))) + != lock_id ): msg = ( f"Lock expired for token {token} while processing. Consider increasing " f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) " - "or use `@rx.event(background=True)` decorator for long-running tasks." + "or use `@rx.event(background=True)` decorator for long-running tasks. " + f"Current lock id: {existing_lock_id!r}, expected lock id: {lock_id!r}." + ( f" Happened in event: {event.name}" if (event := context.get("event")) is not None @@ -440,9 +442,10 @@ async def _try_modify_state( Yields: The state for the token or None if we couldn't get the lock. """ + event_name = event.name if (event := context.get("event")) is not None else None if not self._oplock_enabled: # OpLock is disabled, get a fresh lock, write, and release. - async with self._lock(token) as lock_id: + async with self._lock(token, event_name=event_name) as lock_id: state = await self.get_state(token) yield state await self.set_state(token, state, lock_id=lock_id, **context) @@ -459,7 +462,9 @@ async def _try_modify_state( client_token, _ = _split_substate_key(token) lock_held_ctx = contextlib.AsyncExitStack() try: - lock_id = await lock_held_ctx.enter_async_context(self._lock(token)) + lock_id = await lock_held_ctx.enter_async_context( + self._lock(token, event_name=event_name) + ) except OplockFound: # While waiting for the lock, another process has acquired it, but we can piggy back. pass @@ -1000,11 +1005,14 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: ) @contextlib.asynccontextmanager - async def _lock(self, token: str): + async def _lock( + self, token: str, event_name: str | None = None + ) -> AsyncIterator[bytes]: """Obtain a redis lock for a token. Args: token: The token to obtain a lock for. + event_name: The name of the event associated with the lock. Yields: The ID of the lock (to be passed to set_state). @@ -1013,7 +1021,9 @@ async def _lock(self, token: str): LockExpiredError: If the lock has expired while processing the event. """ lock_key = self._lock_key(token) - lock_id = uuid.uuid4().hex.encode() + lock_id = ( + f"{event_name}:{uuid.uuid4().hex}" if event_name else uuid.uuid4().hex + ).encode() await self._wait_lock(lock_key, lock_id) state_is_locked = True diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 90231eeb7fc..7f9cb78394b 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -18,6 +18,7 @@ from typing_extensions import Self from reflex.base import Base +from reflex.event import Event from reflex.utils import prerequisites from reflex.utils.exceptions import ImmutableStateError from reflex.utils.serializers import can_serialize, serialize, serializer @@ -59,6 +60,7 @@ async def bg_increment(self): def __init__( self, state_instance: BaseState, + event: Event | None = None, parent_state_proxy: StateProxy | None = None, ): """Create a proxy for a state instance. @@ -69,11 +71,13 @@ def __init__( Args: state_instance: The state instance to proxy. + event: The event associated with the state modification context. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ from reflex.state import _substate_key super().__init__(state_instance) + self._self_event = event self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_token = _substate_key( @@ -136,7 +140,7 @@ async def __aenter__(self) -> Self: try: self._self_actx_lock_holder = current_task self._self_actx = self._self_app.modify_state( - token=self._self_substate_token, background=True + token=self._self_substate_token, background=True, event=self._self_event ) mutable_state = await self._self_actx.__aenter__() self._self_mutable = True @@ -294,7 +298,9 @@ async def get_state(self, state_cls: type[T_STATE]) -> T_STATE: ) raise ImmutableStateError(msg) return type(self)( - await self.__wrapped__.get_state(state_cls), parent_state_proxy=self + await self.__wrapped__.get_state(state_cls), + event=self._self_event, + parent_state_proxy=self, ) # pyright: ignore [reportReturnType] async def _as_state_update(self, *args, **kwargs) -> StateUpdate: diff --git a/reflex/state.py b/reflex/state.py index 49c6703088b..535989d8f1b 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1749,7 +1749,7 @@ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: # For background tasks, proxy the state. if handler.is_background: - substate = StateProxy(substate) + substate = StateProxy(substate, event) # Run the event generator and yield state updates. async for update in self._process_event(