Skip to content

Commit a90568d

Browse files
provide event for redis state expiry (#6194)
* provide event for redis state expiry * sad * use : instead of _ Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent e1ba438 commit a90568d

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

reflex/app.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from starlette.requests import ClientDisconnect, Request
4343
from starlette.responses import JSONResponse, Response, StreamingResponse
4444
from starlette.staticfiles import StaticFiles
45+
from typing_extensions import Unpack
4546

4647
from reflex import constants
4748
from reflex.admin import AdminDash
@@ -83,6 +84,7 @@
8384
get_hydrate_event,
8485
noop,
8586
)
87+
from reflex.istate.manager import StateModificationContext
8688
from reflex.istate.proxy import StateProxy
8789
from reflex.page import DECORATED_PAGES
8890
from reflex.route import (
@@ -1571,6 +1573,7 @@ async def modify_state(
15711573
token: str,
15721574
background: bool = False,
15731575
previous_dirty_vars: dict[str, set[str]] | None = None,
1576+
**context: Unpack[StateModificationContext],
15741577
) -> AsyncIterator[BaseState]:
15751578
"""Modify the state out of band.
15761579
@@ -1591,7 +1594,7 @@ async def modify_state(
15911594

15921595
# Get exclusive access to the state.
15931596
async with self.state_manager.modify_state_with_links(
1594-
token, previous_dirty_vars=previous_dirty_vars
1597+
token, previous_dirty_vars=previous_dirty_vars, **context
15951598
) as state:
15961599
# No other event handler can modify the state while in this context.
15971600
yield state
@@ -1624,7 +1627,7 @@ def _process_background(
16241627
if not handler.is_background:
16251628
return None
16261629

1627-
substate = StateProxy(substate)
1630+
substate = StateProxy(substate, event)
16281631

16291632
async def _coro():
16301633
"""Coroutine to process the event and emit updates inside an asyncio.Task.
@@ -2010,7 +2013,7 @@ async def _ndjson_updates():
20102013
"""
20112014
# Process the event.
20122015
async with app.state_manager.modify_state_with_links(
2013-
event.substate_token
2016+
event.substate_token, event=event
20142017
) as state:
20152018
async for update in state._process(event):
20162019
# Postprocess the event.

reflex/istate/manager/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
class StateModificationContext(TypedDict, total=False):
2020
"""The context for modifying state."""
2121

22-
event: ReadOnly[Event]
22+
event: ReadOnly[Event | None]
2323

2424

2525
EmptyContext = StateModificationContext()

reflex/istate/manager/redis.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -363,12 +363,14 @@ async def set_state(
363363
# Check that we're holding the lock.
364364
if (
365365
lock_id is not None
366-
and await self.redis.get(self._lock_key(token)) != lock_id
366+
and (existing_lock_id := await self.redis.get(self._lock_key(token)))
367+
!= lock_id
367368
):
368369
msg = (
369370
f"Lock expired for token {token} while processing. Consider increasing "
370371
f"`app.state_manager.lock_expiration` (currently {self.lock_expiration}) "
371-
"or use `@rx.event(background=True)` decorator for long-running tasks."
372+
"or use `@rx.event(background=True)` decorator for long-running tasks. "
373+
f"Current lock id: {existing_lock_id!r}, expected lock id: {lock_id!r}."
372374
+ (
373375
f" Happened in event: {event.name}"
374376
if (event := context.get("event")) is not None
@@ -440,9 +442,10 @@ async def _try_modify_state(
440442
Yields:
441443
The state for the token or None if we couldn't get the lock.
442444
"""
445+
event_name = event.name if (event := context.get("event")) is not None else None
443446
if not self._oplock_enabled:
444447
# OpLock is disabled, get a fresh lock, write, and release.
445-
async with self._lock(token) as lock_id:
448+
async with self._lock(token, event_name=event_name) as lock_id:
446449
state = await self.get_state(token)
447450
yield state
448451
await self.set_state(token, state, lock_id=lock_id, **context)
@@ -459,7 +462,9 @@ async def _try_modify_state(
459462
client_token, _ = _split_substate_key(token)
460463
lock_held_ctx = contextlib.AsyncExitStack()
461464
try:
462-
lock_id = await lock_held_ctx.enter_async_context(self._lock(token))
465+
lock_id = await lock_held_ctx.enter_async_context(
466+
self._lock(token, event_name=event_name)
467+
)
463468
except OplockFound:
464469
# While waiting for the lock, another process has acquired it, but we can piggy back.
465470
pass
@@ -1000,11 +1005,14 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
10001005
)
10011006

10021007
@contextlib.asynccontextmanager
1003-
async def _lock(self, token: str):
1008+
async def _lock(
1009+
self, token: str, event_name: str | None = None
1010+
) -> AsyncIterator[bytes]:
10041011
"""Obtain a redis lock for a token.
10051012
10061013
Args:
10071014
token: The token to obtain a lock for.
1015+
event_name: The name of the event associated with the lock.
10081016
10091017
Yields:
10101018
The ID of the lock (to be passed to set_state).
@@ -1013,7 +1021,9 @@ async def _lock(self, token: str):
10131021
LockExpiredError: If the lock has expired while processing the event.
10141022
"""
10151023
lock_key = self._lock_key(token)
1016-
lock_id = uuid.uuid4().hex.encode()
1024+
lock_id = (
1025+
f"{event_name}:{uuid.uuid4().hex}" if event_name else uuid.uuid4().hex
1026+
).encode()
10171027

10181028
await self._wait_lock(lock_key, lock_id)
10191029
state_is_locked = True

reflex/istate/proxy.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing_extensions import Self
1919

2020
from reflex.base import Base
21+
from reflex.event import Event
2122
from reflex.utils import prerequisites
2223
from reflex.utils.exceptions import ImmutableStateError
2324
from reflex.utils.serializers import can_serialize, serialize, serializer
@@ -59,6 +60,7 @@ async def bg_increment(self):
5960
def __init__(
6061
self,
6162
state_instance: BaseState,
63+
event: Event | None = None,
6264
parent_state_proxy: StateProxy | None = None,
6365
):
6466
"""Create a proxy for a state instance.
@@ -69,11 +71,13 @@ def __init__(
6971
7072
Args:
7173
state_instance: The state instance to proxy.
74+
event: The event associated with the state modification context.
7275
parent_state_proxy: The parent state proxy, for linked mutability and context tracking.
7376
"""
7477
from reflex.state import _substate_key
7578

7679
super().__init__(state_instance)
80+
self._self_event = event
7781
self._self_app = prerequisites.get_and_validate_app().app
7882
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
7983
self._self_substate_token = _substate_key(
@@ -136,7 +140,7 @@ async def __aenter__(self) -> Self:
136140
try:
137141
self._self_actx_lock_holder = current_task
138142
self._self_actx = self._self_app.modify_state(
139-
token=self._self_substate_token, background=True
143+
token=self._self_substate_token, background=True, event=self._self_event
140144
)
141145
mutable_state = await self._self_actx.__aenter__()
142146
self._self_mutable = True
@@ -294,7 +298,9 @@ async def get_state(self, state_cls: type[T_STATE]) -> T_STATE:
294298
)
295299
raise ImmutableStateError(msg)
296300
return type(self)(
297-
await self.__wrapped__.get_state(state_cls), parent_state_proxy=self
301+
await self.__wrapped__.get_state(state_cls),
302+
event=self._self_event,
303+
parent_state_proxy=self,
298304
) # pyright: ignore [reportReturnType]
299305

300306
async def _as_state_update(self, *args, **kwargs) -> StateUpdate:

reflex/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1749,7 +1749,7 @@ async def _process(self, event: Event) -> AsyncIterator[StateUpdate]:
17491749

17501750
# For background tasks, proxy the state.
17511751
if handler.is_background:
1752-
substate = StateProxy(substate)
1752+
substate = StateProxy(substate, event)
17531753

17541754
# Run the event generator and yield state updates.
17551755
async for update in self._process_event(

0 commit comments

Comments
 (0)