Skip to content

Commit 870bdde

Browse files
committed
provide event for redis state expiry
1 parent d45a1bb commit 870bdde

5 files changed

Lines changed: 40 additions & 14 deletions

File tree

reflex/app.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,15 @@
2929
from pathlib import Path
3030
from timeit import default_timer as timer
3131
from types import SimpleNamespace
32-
from typing import TYPE_CHECKING, Any, BinaryIO, ParamSpec, get_args, get_type_hints
32+
from typing import (
33+
TYPE_CHECKING,
34+
Any,
35+
BinaryIO,
36+
ParamSpec,
37+
Unpack,
38+
get_args,
39+
get_type_hints,
40+
)
3341

3442
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
3543
from socketio import ASGIApp as EngineIOApp
@@ -83,6 +91,7 @@
8391
get_hydrate_event,
8492
noop,
8593
)
94+
from reflex.istate.manager import StateModificationContext
8695
from reflex.istate.proxy import StateProxy
8796
from reflex.page import DECORATED_PAGES
8897
from reflex.route import (
@@ -1571,6 +1580,7 @@ async def modify_state(
15711580
token: str,
15721581
background: bool = False,
15731582
previous_dirty_vars: dict[str, set[str]] | None = None,
1583+
**context: Unpack[StateModificationContext],
15741584
) -> AsyncIterator[BaseState]:
15751585
"""Modify the state out of band.
15761586
@@ -1591,7 +1601,7 @@ async def modify_state(
15911601

15921602
# Get exclusive access to the state.
15931603
async with self.state_manager.modify_state_with_links(
1594-
token, previous_dirty_vars=previous_dirty_vars
1604+
token, previous_dirty_vars=previous_dirty_vars, **context
15951605
) as state:
15961606
# No other event handler can modify the state while in this context.
15971607
yield state
@@ -1624,7 +1634,7 @@ def _process_background(
16241634
if not handler.is_background:
16251635
return None
16261636

1627-
substate = StateProxy(substate)
1637+
substate = StateProxy(substate, event)
16281638

16291639
async def _coro():
16301640
"""Coroutine to process the event and emit updates inside an asyncio.Task.
@@ -2042,7 +2052,7 @@ async def _ndjson_updates():
20422052
"""
20432053
# Process the event.
20442054
async with app.state_manager.modify_state_with_links(
2045-
event.substate_token
2055+
event.substate_token, event=event
20462056
) as state:
20472057
async for update in state._process(event):
20482058
# Postprocess the event.

reflex/istate/manager/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class StateModificationContext(TypedDict, total=False):
2020
"""The context for modifying state."""
2121

22-
event: ReadOnly[Event]
22+
event: ReadOnly[Event | None]
2323

2424

2525
EmptyContext = StateModificationContext()

reflex/istate/manager/redis.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,14 @@ async def set_state(
363363
# Check that we're holding the lock.
364364
if (
365365
lock_id is not None
366-
and await self.redis.get(self._lock_key(token)) != lock_id
366+
and (existing_lock_id := await self.redis.get(self._lock_key(token)))
367+
!= lock_id
367368
):
368369
msg = (
369370
f"Lock expired for token {token} while processing. Consider increasing "
370371
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
371-
"or use `@rx.event(background=True)` decorator for long-running tasks."
372+
"or use `@rx.event(background=True)` decorator for long-running tasks. "
373+
f"Current lock id: {existing_lock_id!r}, expected lock id: {lock_id!r}."
372374
+ (
373375
f" Happened in event: {event.name}"
374376
if (event := context.get("event")) is not None
@@ -440,9 +442,10 @@ async def _try_modify_state(
440442
Yields:
441443
The state for the token or None if we couldn't get the lock.
442444
"""
445+
event_name = event.name if (event := context.get("event")) is not None else None
443446
if not self._oplock_enabled:
444447
# OpLock is disabled, get a fresh lock, write, and release.
445-
async with self._lock(token) as lock_id:
448+
async with self._lock(token, event_name=event_name) as lock_id:
446449
state = await self.get_state(token)
447450
yield state
448451
await self.set_state(token, state, lock_id=lock_id, **context)
@@ -459,7 +462,9 @@ async def _try_modify_state(
459462
client_token, _ = _split_substate_key(token)
460463
lock_held_ctx = contextlib.AsyncExitStack()
461464
try:
462-
lock_id = await lock_held_ctx.enter_async_context(self._lock(token))
465+
lock_id = await lock_held_ctx.enter_async_context(
466+
self._lock(token, event_name=event_name)
467+
)
463468
except OplockFound:
464469
# While waiting for the lock, another process has acquired it, but we can piggy back.
465470
pass
@@ -1000,11 +1005,14 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
10001005
)
10011006

10021007
@contextlib.asynccontextmanager
1003-
async def _lock(self, token: str):
1008+
async def _lock(
1009+
self, token: str, event_name: str | None = None
1010+
) -> AsyncIterator[bytes]:
10041011
"""Obtain a redis lock for a token.
10051012
10061013
Args:
10071014
token: The token to obtain a lock for.
1015+
event_name: The name of the event associated with the lock.
10081016
10091017
Yields:
10101018
The ID of the lock (to be passed to set_state).
@@ -1013,7 +1021,9 @@ async def _lock(self, token: str):
10131021
LockExpiredError: If the lock has expired while processing the event.
10141022
"""
10151023
lock_key = self._lock_key(token)
1016-
lock_id = uuid.uuid4().hex.encode()
1024+
lock_id = (
1025+
f"{event_name}_{uuid.uuid4().hex}" if event_name else uuid.uuid4().hex
1026+
).encode()
10171027

10181028
await self._wait_lock(lock_key, lock_id)
10191029
state_is_locked = True

reflex/istate/proxy.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing_extensions import Self
1919

2020
from reflex.base import Base
21+
from reflex.event import Event
2122
from reflex.utils import prerequisites
2223
from reflex.utils.exceptions import ImmutableStateError
2324
from reflex.utils.serializers import can_serialize, serialize, serializer
@@ -59,6 +60,7 @@ async def bg_increment(self):
5960
def __init__(
6061
self,
6162
state_instance: BaseState,
63+
event: Event | None = None,
6264
parent_state_proxy: StateProxy | None = None,
6365
):
6466
"""Create a proxy for a state instance.
@@ -69,11 +71,13 @@ def __init__(
6971
7072
Args:
7173
state_instance: The state instance to proxy.
74+
event: The event associated with the state modification context.
7275
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
7376
"""
7477
from reflex.state import _substate_key
7578

7679
super().__init__(state_instance)
80+
self._self_event = event
7781
self._self_app = prerequisites.get_and_validate_app().app
7882
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
7983
self._self_substate_token = _substate_key(
@@ -136,7 +140,7 @@ async def __aenter__(self) -> Self:
136140
try:
137141
self._self_actx_lock_holder = current_task
138142
self._self_actx = self._self_app.modify_state(
139-
token=self._self_substate_token, background=True
143+
token=self._self_substate_token, background=True, event=self._self_event
140144
)
141145
mutable_state = await self._self_actx.__aenter__()
142146
self._self_mutable = True
@@ -294,7 +298,9 @@ async def get_state(self, state_cls: type[T_STATE]) -> T_STATE:
294298
)
295299
raise ImmutableStateError(msg)
296300
return type(self)(
297-
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
301+
await self.__wrapped__.get_state(state_cls),
302+
event=self._self_event,
303+
parent_state_proxy=self,
298304
) # pyright: ignore [reportReturnType]
299305

300306
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:

reflex/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1749,7 +1749,7 @@ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
17491749

17501750
# For background tasks, proxy the state.
17511751
if handler.is_background:
1752-
substate = StateProxy(substate)
1752+
substate = StateProxy(substate, event)
17531753

17541754
# Run the event generator and yield state updates.
17551755
async for update in self._process_event(

0 commit comments

Comments
 (0)