Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.staticfiles import StaticFiles
from typing_extensions import Unpack

from reflex import constants
from reflex.admin import AdminDash
Expand Down Expand Up @@ -83,6 +84,7 @@
get_hydrate_event,
noop,
)
from reflex.istate.manager import StateModificationContext
from reflex.istate.proxy import StateProxy
from reflex.page import DECORATED_PAGES
from reflex.route import (
Expand Down Expand Up @@ -1571,6 +1573,7 @@ async def modify_state(
token: str,
background: bool = False,
previous_dirty_vars: dict[str, set[str]] | None = None,
**context: Unpack[StateModificationContext],
) -> AsyncIterator[BaseState]:
"""Modify the state out of band.

Expand All @@ -1591,7 +1594,7 @@ async def modify_state(

# Get exclusive access to the state.
async with self.state_manager.modify_state_with_links(
token, previous_dirty_vars=previous_dirty_vars
token, previous_dirty_vars=previous_dirty_vars, **context
) as state:
# No other event handler can modify the state while in this context.
yield state
Expand Down Expand Up @@ -1624,7 +1627,7 @@ def _process_background(
if not handler.is_background:
return None

substate = StateProxy(substate)
substate = StateProxy(substate, event)

async def _coro():
"""Coroutine to process the event and emit updates inside an asyncio.Task.
Expand Down Expand Up @@ -2042,7 +2045,7 @@ async def _ndjson_updates():
"""
# Process the event.
async with app.state_manager.modify_state_with_links(
event.substate_token
event.substate_token, event=event
) as state:
async for update in state._process(event):
# Postprocess the event.
Expand Down
2 changes: 1 addition & 1 deletion reflex/istate/manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class StateModificationContext(TypedDict, total=False):
"""The context for modifying state."""

event: ReadOnly[Event]
event: ReadOnly[Event | None]


EmptyContext = StateModificationContext()
Expand Down
22 changes: 16 additions & 6 deletions reflex/istate/manager/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,14 @@ async def set_state(
# Check that we're holding the lock.
if (
lock_id is not None
and await self.redis.get(self._lock_key(token)) != lock_id
and (existing_lock_id := await self.redis.get(self._lock_key(token)))
!= lock_id
):
msg = (
f"Lock expired for token {token} while processing. Consider increasing "
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
"or use `@rx.event(background=True)` decorator for long-running tasks."
"or use `@rx.event(background=True)` decorator for long-running tasks. "
f"Current lock id: {existing_lock_id!r}, expected lock id: {lock_id!r}."
+ (
f" Happened in event: {event.name}"
if (event := context.get("event")) is not None
Expand Down Expand Up @@ -440,9 +442,10 @@ async def _try_modify_state(
Yields:
The state for the token or None if we couldn't get the lock.
"""
event_name = event.name if (event := context.get("event")) is not None else None
if not self._oplock_enabled:
# OpLock is disabled, get a fresh lock, write, and release.
async with self._lock(token) as lock_id:
async with self._lock(token, event_name=event_name) as lock_id:
state = await self.get_state(token)
yield state
await self.set_state(token, state, lock_id=lock_id, **context)
Expand All @@ -459,7 +462,9 @@ async def _try_modify_state(
client_token, _ = _split_substate_key(token)
lock_held_ctx = contextlib.AsyncExitStack()
try:
lock_id = await lock_held_ctx.enter_async_context(self._lock(token))
lock_id = await lock_held_ctx.enter_async_context(
self._lock(token, event_name=event_name)
)
except OplockFound:
# While waiting for the lock, another process has acquired it, but we can piggy back.
pass
Expand Down Expand Up @@ -1000,11 +1005,14 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
)

@contextlib.asynccontextmanager
async def _lock(self, token: str):
async def _lock(
self, token: str, event_name: str | None = None
) -> AsyncIterator[bytes]:
"""Obtain a redis lock for a token.

Args:
token: The token to obtain a lock for.
event_name: The name of the event associated with the lock.

Yields:
The ID of the lock (to be passed to set_state).
Expand All @@ -1013,7 +1021,9 @@ async def _lock(self, token: str):
LockExpiredError: If the lock has expired while processing the event.
"""
lock_key = self._lock_key(token)
lock_id = uuid.uuid4().hex.encode()
lock_id = (
f"{event_name}:{uuid.uuid4().hex}" if event_name else uuid.uuid4().hex
).encode()

await self._wait_lock(lock_key, lock_id)
state_is_locked = True
Expand Down
10 changes: 8 additions & 2 deletions reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing_extensions import Self

from reflex.base import Base
from reflex.event import Event
from reflex.utils import prerequisites
from reflex.utils.exceptions import ImmutableStateError
from reflex.utils.serializers import can_serialize, serialize, serializer
Expand Down Expand Up @@ -59,6 +60,7 @@ async def bg_increment(self):
def __init__(
self,
state_instance: BaseState,
event: Event | None = None,
parent_state_proxy: StateProxy | None = None,
):
"""Create a proxy for a state instance.
Expand All @@ -69,11 +71,13 @@ def __init__(

Args:
state_instance: The state instance to proxy.
event: The event associated with the state modification context.
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
"""
from reflex.state import _substate_key

super().__init__(state_instance)
self._self_event = event
self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_substate_token = _substate_key(
Expand Down Expand Up @@ -136,7 +140,7 @@ async def __aenter__(self) -> Self:
try:
self._self_actx_lock_holder = current_task
self._self_actx = self._self_app.modify_state(
token=self._self_substate_token, background=True
token=self._self_substate_token, background=True, event=self._self_event
)
mutable_state = await self._self_actx.__aenter__()
self._self_mutable = True
Expand Down Expand Up @@ -294,7 +298,9 @@ async def get_state(self, state_cls: type[T_STATE]) -> T_STATE:
)
raise ImmutableStateError(msg)
return type(self)(
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
await self.__wrapped__.get_state(state_cls),
event=self._self_event,
parent_state_proxy=self,
) # pyright: ignore [reportReturnType]

async def _as_state_update(self, *args, **kwargs) -> StateUpdate:
Expand Down
2 changes: 1 addition & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,7 +1749,7 @@ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:

# For background tasks, proxy the state.
if handler.is_background:
substate = StateProxy(substate)
substate = StateProxy(substate, event)

# Run the event generator and yield state updates.
async for update in self._process_event(
Expand Down
Loading