Skip to content

Commit ab71d84

Browse files
committed
Add in-memory state expiration to StateManagerMemory
Extract reusable expiration logic into StateManagerExpiration base class that tracks token access times and purges expired states using a deadline-ordered heap. Integrate it into StateManagerMemory with a background asyncio task that automatically cleans up idle client states.
1 parent 7ee3026 commit ab71d84

4 files changed

Lines changed: 623 additions & 9 deletions

File tree

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""Internal helpers for in-memory state expiration."""
2+
3+
import asyncio
4+
import contextlib
5+
import dataclasses
6+
import heapq
7+
import time
8+
from typing import ClassVar
9+
10+
from reflex.state import BaseState
11+
12+
from . import _default_token_expiration
13+
14+
15+
@dataclasses.dataclass
16+
class StateManagerExpiration:
17+
"""Internal base for managers with in-memory state expiration."""
18+
19+
_locked_expiration_poll_interval: ClassVar[float] = 0.1
20+
_recheck_expired_locks_on_unlock: ClassVar[bool] = False
21+
22+
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
23+
24+
# The mapping of client ids to states.
25+
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
26+
27+
# The dict of mutexes for each client.
28+
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
29+
default_factory=dict,
30+
init=False,
31+
)
32+
33+
# The latest expiration deadline for each token.
34+
_token_expires_at: dict[str, float] = dataclasses.field(
35+
default_factory=dict,
36+
init=False,
37+
)
38+
39+
# Last time a token was touched.
40+
_token_last_touched: dict[str, float] = dataclasses.field(
41+
default_factory=dict,
42+
init=False,
43+
)
44+
45+
# Deadline-ordered token expiration heap.
46+
_token_expiration_heap: list[tuple[float, str]] = dataclasses.field(
47+
default_factory=list,
48+
init=False,
49+
repr=False,
50+
)
51+
52+
# Tokens whose expiration is deferred until their state lock is released.
53+
_pending_locked_expirations: set[str] = dataclasses.field(
54+
default_factory=set,
55+
init=False,
56+
repr=False,
57+
)
58+
59+
# Wake any background expiration worker when token activity changes.
60+
_token_activity: asyncio.Event = dataclasses.field(
61+
default_factory=asyncio.Event,
62+
init=False,
63+
repr=False,
64+
)
65+
66+
_scheduled_expiration_deadline: float | None = dataclasses.field(
67+
default=None,
68+
init=False,
69+
repr=False,
70+
)
71+
72+
def _touch_token(self, token: str):
73+
"""Record access for a token.
74+
75+
Args:
76+
token: The token that was accessed.
77+
"""
78+
touched_at = time.time()
79+
expires_at = touched_at + self.token_expiration
80+
self._token_last_touched[token] = touched_at
81+
self._token_expires_at[token] = expires_at
82+
self._pending_locked_expirations.discard(token)
83+
heapq.heappush(self._token_expiration_heap, (expires_at, token))
84+
self._maybe_compact_expiration_heap()
85+
if (
86+
self._scheduled_expiration_deadline is None
87+
or expires_at <= self._scheduled_expiration_deadline
88+
):
89+
self._token_activity.set()
90+
91+
def _maybe_compact_expiration_heap(self):
92+
"""Rebuild the heap when stale deadline entries accumulate."""
93+
if len(self._token_expiration_heap) <= (2 * len(self._token_expires_at)) + 1:
94+
return
95+
self._token_expiration_heap = [
96+
(expires_at, token)
97+
for token, expires_at in self._token_expires_at.items()
98+
if token not in self._pending_locked_expirations
99+
]
100+
heapq.heapify(self._token_expiration_heap)
101+
102+
def _next_expiration(self) -> tuple[float, str] | None:
103+
"""Get the next valid token expiration from the heap.
104+
105+
Returns:
106+
The next expiration deadline and token, or None if there are no
107+
active deadlines to process.
108+
"""
109+
while self._token_expiration_heap:
110+
expires_at, token = self._token_expiration_heap[0]
111+
current_expiration = self._token_expires_at.get(token)
112+
if (
113+
current_expiration != expires_at
114+
or token in self._pending_locked_expirations
115+
):
116+
heapq.heappop(self._token_expiration_heap)
117+
continue
118+
return expires_at, token
119+
return None
120+
121+
def _purge_token(self, token: str):
122+
"""Remove a token from all in-memory expiration bookkeeping.
123+
124+
Args:
125+
token: The token to purge.
126+
"""
127+
self._token_last_touched.pop(token, None)
128+
self._token_expires_at.pop(token, None)
129+
self.states.pop(token, None)
130+
self._states_locks.pop(token, None)
131+
self._pending_locked_expirations.discard(token)
132+
133+
def _purge_expired_tokens(
134+
self,
135+
now: float | None = None,
136+
) -> list[str]:
137+
"""Purge expired in-memory state entries.
138+
139+
If a token's state lock is currently held, defer cleanup until a later pass
140+
to avoid replacing the state while it is being modified.
141+
142+
Args:
143+
now: The time to compare against.
144+
145+
Returns:
146+
The list of purged tokens.
147+
"""
148+
now = time.time() if now is None else now
149+
expired_tokens = []
150+
while (
151+
next_expiration := self._next_expiration()
152+
) is not None and next_expiration[0] <= now:
153+
_expires_at, token = heapq.heappop(self._token_expiration_heap)
154+
if (
155+
state_lock := self._states_locks.get(token)
156+
) is not None and state_lock.locked():
157+
self._pending_locked_expirations.add(token)
158+
continue
159+
self._purge_token(token)
160+
expired_tokens.append(token)
161+
return expired_tokens
162+
163+
def _next_expiration_in(
164+
self,
165+
now: float | None = None,
166+
) -> float | None:
167+
"""Get the delay until the next expiration check should run.
168+
169+
Args:
170+
now: The time to compare against.
171+
172+
Returns:
173+
The number of seconds until the next check, or None when there are no
174+
tracked tokens.
175+
"""
176+
if (next_expiration := self._next_expiration()) is None:
177+
if (
178+
self._pending_locked_expirations
179+
and not self._recheck_expired_locks_on_unlock
180+
):
181+
return self._locked_expiration_poll_interval
182+
return None
183+
184+
now = time.time() if now is None else now
185+
next_delay = max(0.0, next_expiration[0] - now)
186+
if (
187+
self._pending_locked_expirations
188+
and not self._recheck_expired_locks_on_unlock
189+
):
190+
return min(next_delay, self._locked_expiration_poll_interval)
191+
return next_delay
192+
193+
def _reset_token_activity_wait(self):
194+
"""Reset the token activity event before waiting."""
195+
self._token_activity.clear()
196+
197+
def _prepare_expiration_wait(
198+
self,
199+
*,
200+
now: float | None = None,
201+
default_timeout: float | None = None,
202+
) -> float | None:
203+
"""Prepare the next wait window for an expiration worker.
204+
205+
Args:
206+
now: The current time.
207+
default_timeout: A fallback timeout when there are no in-memory token
208+
deadlines to wait on.
209+
210+
Returns:
211+
The timeout to use for the next wait.
212+
"""
213+
self._reset_token_activity_wait()
214+
now = time.time() if now is None else now
215+
timeout = self._next_expiration_in(now=now)
216+
if timeout is None:
217+
timeout = default_timeout
218+
elif default_timeout is not None:
219+
timeout = min(timeout, default_timeout)
220+
self._scheduled_expiration_deadline = None if timeout is None else now + timeout
221+
return timeout
222+
223+
def _notify_token_unlocked(self, token: str):
224+
"""Requeue a deferred expiration check for a token after its lock is released.
225+
226+
Args:
227+
token: The unlocked token.
228+
"""
229+
if token not in self._pending_locked_expirations:
230+
return
231+
self._pending_locked_expirations.discard(token)
232+
if (expires_at := self._token_expires_at.get(token)) is None:
233+
return
234+
heapq.heappush(self._token_expiration_heap, (expires_at, token))
235+
self._token_activity.set()
236+
237+
async def _wait_for_token_activity(self, timeout: float | None):
238+
"""Wait for token activity or a timeout.
239+
240+
Args:
241+
timeout: The maximum time to wait. When None, waits indefinitely.
242+
"""
243+
try:
244+
if timeout is None:
245+
await self._token_activity.wait()
246+
return
247+
with contextlib.suppress(asyncio.TimeoutError):
248+
await asyncio.wait_for(self._token_activity.wait(), timeout=timeout)
249+
finally:
250+
self._scheduled_expiration_deadline = None

reflex/istate/manager/memory.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,64 @@
33
import asyncio
44
import contextlib
55
import dataclasses
6+
import time
67
from collections.abc import AsyncIterator
8+
from typing import ClassVar
79

810
from typing_extensions import Unpack, override
911

1012
from reflex.istate.manager import StateManager, StateModificationContext
13+
from reflex.istate.manager._expiration import StateManagerExpiration
1114
from reflex.state import BaseState, _split_substate_key
15+
from reflex.utils import console
16+
17+
_EXPIRATION_ERROR_RETRY_SECONDS = 1.0
1218

1319

1420
@dataclasses.dataclass
15-
class StateManagerMemory(StateManager):
21+
class StateManagerMemory(StateManagerExpiration, StateManager):
1622
"""A state manager that stores states in memory."""
1723

18-
# The mapping of client ids to states.
19-
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
24+
_recheck_expired_locks_on_unlock: ClassVar[bool] = True
2025

2126
# The mutex ensures the dict of mutexes is updated exclusively
2227
_state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
2328

24-
# The dict of mutexes for each client
25-
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
26-
default_factory=dict, init=False
27-
)
29+
_expiration_task: asyncio.Task | None = None
30+
31+
async def _expire_states_once(self):
32+
"""Perform one expiration pass and wait for the next check."""
33+
try:
34+
now = time.time()
35+
self._purge_expired_tokens(now=now)
36+
await self._wait_for_token_activity(
37+
self._prepare_expiration_wait(now=now),
38+
)
39+
except asyncio.CancelledError:
40+
raise
41+
except Exception as err:
42+
console.error(f"Error expiring in-memory states: {err!r}")
43+
await asyncio.sleep(_EXPIRATION_ERROR_RETRY_SECONDS)
44+
45+
async def _expire_states(self):
46+
"""Long running task that removes expired states from memory.
47+
48+
Raises:
49+
asyncio.CancelledError: When the task is cancelled.
50+
"""
51+
while True:
52+
await self._expire_states_once()
53+
54+
async def _schedule_expiration_task(self):
55+
"""Schedule the expiration task if it is not already running."""
56+
if self._expiration_task is None or self._expiration_task.done():
57+
async with self._state_manager_lock:
58+
if self._expiration_task is None or self._expiration_task.done():
59+
self._expiration_task = asyncio.create_task(
60+
self._expire_states(),
61+
name="StateManagerMemory|ExpirationProcessor",
62+
)
63+
await asyncio.sleep(0)
2864

2965
@override
3066
async def get_state(self, token: str) -> BaseState:
@@ -38,6 +74,8 @@ async def get_state(self, token: str) -> BaseState:
3874
"""
3975
# Memory state manager ignores the substate suffix and always returns the top-level state.
4076
token = _split_substate_key(token)[0]
77+
self._touch_token(token)
78+
await self._schedule_expiration_task()
4179
if token not in self.states:
4280
self.states[token] = self.state(_reflex_internal_init=True)
4381
return self.states[token]
@@ -57,7 +95,9 @@ async def set_state(
5795
context: The state modification context.
5896
"""
5997
token = _split_substate_key(token)[0]
98+
self._touch_token(token)
6099
self.states[token] = state
100+
await self._schedule_expiration_task()
61101

62102
@override
63103
@contextlib.asynccontextmanager
@@ -80,5 +120,17 @@ async def modify_state(
80120
if token not in self._states_locks:
81121
self._states_locks[token] = asyncio.Lock()
82122

83-
async with self._states_locks[token]:
84-
yield await self.get_state(token)
123+
try:
124+
async with self._states_locks[token]:
125+
yield await self.get_state(token)
126+
finally:
127+
self._notify_token_unlocked(token)
128+
129+
async def close(self):
130+
"""Cancel the in-memory expiration task."""
131+
async with self._state_manager_lock:
132+
if self._expiration_task:
133+
self._expiration_task.cancel()
134+
with contextlib.suppress(asyncio.CancelledError):
135+
await self._expiration_task
136+
self._expiration_task = None

0 commit comments

Comments
 (0)