Skip to content

Commit c0b8742

Browse files
committed
Simplify StateManagerMemory expiration internals
Remove dead _token_last_touched dict, replace hand-rolled task scheduling with ensure_task, move heap compaction off the hot path, and fix touch ordering in get_state/set_state.
1 parent ab71d84 commit c0b8742

3 files changed

Lines changed: 35 additions & 66 deletions

File tree

reflex/istate/manager/_expiration.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Internal helpers for in-memory state expiration."""
22

3+
from __future__ import annotations
4+
35
import asyncio
46
import contextlib
57
import dataclasses
68
import heapq
79
import time
8-
from typing import ClassVar
9-
10-
from reflex.state import BaseState
10+
from typing import TYPE_CHECKING, ClassVar
1111

1212
from . import _default_token_expiration
1313

14+
if TYPE_CHECKING:
15+
from reflex.state import BaseState
16+
1417

1518
@dataclasses.dataclass
1619
class StateManagerExpiration:
@@ -36,12 +39,6 @@ class StateManagerExpiration:
3639
init=False,
3740
)
3841

39-
# Last time a token was touched.
40-
_token_last_touched: dict[str, float] = dataclasses.field(
41-
default_factory=dict,
42-
init=False,
43-
)
44-
4542
# Deadline-ordered token expiration heap.
4643
_token_expiration_heap: list[tuple[float, str]] = dataclasses.field(
4744
default_factory=list,
@@ -75,13 +72,10 @@ def _touch_token(self, token: str):
7572
Args:
7673
token: The token that was accessed.
7774
"""
78-
touched_at = time.time()
79-
expires_at = touched_at + self.token_expiration
80-
self._token_last_touched[token] = touched_at
75+
expires_at = time.time() + self.token_expiration
8176
self._token_expires_at[token] = expires_at
8277
self._pending_locked_expirations.discard(token)
8378
heapq.heappush(self._token_expiration_heap, (expires_at, token))
84-
self._maybe_compact_expiration_heap()
8579
if (
8680
self._scheduled_expiration_deadline is None
8781
or expires_at <= self._scheduled_expiration_deadline
@@ -124,7 +118,6 @@ def _purge_token(self, token: str):
124118
Args:
125119
token: The token to purge.
126120
"""
127-
self._token_last_touched.pop(token, None)
128121
self._token_expires_at.pop(token, None)
129122
self.states.pop(token, None)
130123
self._states_locks.pop(token, None)
@@ -133,20 +126,16 @@ def _purge_token(self, token: str):
133126
def _purge_expired_tokens(
134127
self,
135128
now: float | None = None,
136-
) -> list[str]:
129+
):
137130
"""Purge expired in-memory state entries.
138131
139132
If a token's state lock is currently held, defer cleanup until a later pass
140133
to avoid replacing the state while it is being modified.
141134
142135
Args:
143136
now: The time to compare against.
144-
145-
Returns:
146-
The list of purged tokens.
147137
"""
148138
now = time.time() if now is None else now
149-
expired_tokens = []
150139
while (
151140
next_expiration := self._next_expiration()
152141
) is not None and next_expiration[0] <= now:
@@ -157,8 +146,7 @@ def _purge_expired_tokens(
157146
self._pending_locked_expirations.add(token)
158147
continue
159148
self._purge_token(token)
160-
expired_tokens.append(token)
161-
return expired_tokens
149+
self._maybe_compact_expiration_heap()
162150

163151
def _next_expiration_in(
164152
self,

reflex/istate/manager/memory.py

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
from reflex.istate.manager import StateManager, StateModificationContext
1313
from reflex.istate.manager._expiration import StateManagerExpiration
1414
from reflex.state import BaseState, _split_substate_key
15-
from reflex.utils import console
16-
17-
_EXPIRATION_ERROR_RETRY_SECONDS = 1.0
15+
from reflex.utils.tasks import ensure_task
1816

1917

2018
@dataclasses.dataclass
@@ -26,41 +24,24 @@ class StateManagerMemory(StateManagerExpiration, StateManager):
2624
# The mutex ensures the dict of mutexes is updated exclusively
2725
_state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock())
2826

29-
_expiration_task: asyncio.Task | None = None
27+
_expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False)
3028

3129
async def _expire_states_once(self):
3230
"""Perform one expiration pass and wait for the next check."""
33-
try:
34-
now = time.time()
35-
self._purge_expired_tokens(now=now)
36-
await self._wait_for_token_activity(
37-
self._prepare_expiration_wait(now=now),
38-
)
39-
except asyncio.CancelledError:
40-
raise
41-
except Exception as err:
42-
console.error(f"Error expiring in-memory states: {err!r}")
43-
await asyncio.sleep(_EXPIRATION_ERROR_RETRY_SECONDS)
44-
45-
async def _expire_states(self):
46-
"""Long running task that removes expired states from memory.
47-
48-
Raises:
49-
asyncio.CancelledError: When the task is cancelled.
50-
"""
51-
while True:
52-
await self._expire_states_once()
53-
54-
async def _schedule_expiration_task(self):
55-
"""Schedule the expiration task if it is not already running."""
56-
if self._expiration_task is None or self._expiration_task.done():
57-
async with self._state_manager_lock:
58-
if self._expiration_task is None or self._expiration_task.done():
59-
self._expiration_task = asyncio.create_task(
60-
self._expire_states(),
61-
name="StateManagerMemory|ExpirationProcessor",
62-
)
63-
await asyncio.sleep(0)
31+
now = time.time()
32+
self._purge_expired_tokens(now=now)
33+
await self._wait_for_token_activity(
34+
self._prepare_expiration_wait(now=now),
35+
)
36+
37+
def _ensure_expiration_task(self):
38+
"""Ensure the expiration background task is running."""
39+
ensure_task(
40+
self,
41+
"_expiration_task",
42+
self._expire_states_once,
43+
suppress_exceptions=[Exception],
44+
)
6445

6546
@override
6647
async def get_state(self, token: str) -> BaseState:
@@ -74,10 +55,10 @@ async def get_state(self, token: str) -> BaseState:
7455
"""
7556
# Memory state manager ignores the substate suffix and always returns the top-level state.
7657
token = _split_substate_key(token)[0]
77-
self._touch_token(token)
78-
await self._schedule_expiration_task()
7958
if token not in self.states:
8059
self.states[token] = self.state(_reflex_internal_init=True)
60+
self._touch_token(token)
61+
self._ensure_expiration_task()
8162
return self.states[token]
8263

8364
@override
@@ -95,9 +76,9 @@ async def set_state(
9576
context: The state modification context.
9677
"""
9778
token = _split_substate_key(token)[0]
98-
self._touch_token(token)
9979
self.states[token] = state
100-
await self._schedule_expiration_task()
80+
self._touch_token(token)
81+
self._ensure_expiration_task()
10182

10283
@override
10384
@contextlib.asynccontextmanager

tests/units/istate/manager/test_expiration.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ async def test_memory_state_manager_evicts_expired_state(
6363

6464
assert token in state_manager_memory.states
6565
assert token in state_manager_memory._states_locks
66-
assert token in state_manager_memory._token_last_touched
66+
assert token in state_manager_memory._token_expires_at
6767

6868
await _poll_until(
6969
lambda: (
7070
token not in state_manager_memory.states
7171
and token not in state_manager_memory._states_locks
72-
and token not in state_manager_memory._token_last_touched
72+
and token not in state_manager_memory._token_expires_at
7373
)
7474
)
7575

@@ -84,13 +84,13 @@ async def test_memory_state_manager_get_state_refreshes_expiration(
8484
state = await state_manager_memory.get_state(state_token)
8585
assert isinstance(state, ExpiringState)
8686
state.value = 7
87-
first_touch = state_manager_memory._token_last_touched[token]
87+
first_expires_at = state_manager_memory._token_expires_at[token]
8888

8989
await asyncio.sleep(0.6)
9090

9191
same_state = await state_manager_memory.get_state(state_token)
9292
assert same_state is state
93-
assert state_manager_memory._token_last_touched[token] > first_touch
93+
assert state_manager_memory._token_expires_at[token] > first_expires_at
9494

9595
await asyncio.sleep(0.6)
9696

@@ -108,13 +108,13 @@ async def test_memory_state_manager_set_state_refreshes_expiration(
108108
state = await state_manager_memory.get_state(state_token)
109109
assert isinstance(state, ExpiringState)
110110
state.value = 17
111-
first_touch = state_manager_memory._token_last_touched[token]
111+
first_expires_at = state_manager_memory._token_expires_at[token]
112112

113113
await asyncio.sleep(0.6)
114114

115115
await state_manager_memory.set_state(state_token, state)
116116

117-
assert state_manager_memory._token_last_touched[token] > first_touch
117+
assert state_manager_memory._token_expires_at[token] > first_expires_at
118118

119119
await asyncio.sleep(0.6)
120120

0 commit comments

Comments
 (0)