Skip to content

Commit 931ac2c

Browse files
Add in-memory state expiration to StateManagerMemory (#6201)
* Add in-memory state expiration to StateManagerMemory Extract reusable expiration logic into StateManagerExpiration base class that tracks token access times and purges expired states using a deadline-ordered heap. Integrate it into StateManagerMemory with a background asyncio task that automatically cleans up idle client states. * Simplify StateManagerMemory expiration internals Remove dead _token_last_touched dict, replace hand-rolled task scheduling with ensure_task, move heap compaction off the hot path, and fix touch ordering in get_state/set_state. * Fix StateManager.create to respect explicit memory mode when Redis URL is set Previously, StateManager.create always overrode the mode to REDIS when a Redis URL was detected, ignoring an explicitly configured memory mode. Now it only auto-promotes to REDIS when state_manager_mode was not explicitly set. Adds a test verifying the explicit mode is honored. * feat: add in-memory state expiration to StateManagerMemory Implement automatic eviction of idle client states in the memory state manager using a heap-based expiration system. States are touched on get/set and purged after token_expiration seconds of inactivity. Locked states defer eviction until their lock is released. Also fix StateManager.create to respect an explicit memory mode when a Redis URL is configured. * feat: Addded expire extension
1 parent b1a1c78 commit 931ac2c

5 files changed

Lines changed: 533 additions & 13 deletions

File tree

reflex/istate/manager/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def create(cls, state: type[BaseState]):
4949
InvalidStateManagerModeError: If the state manager mode is invalid.
5050
"""
5151
config = get_config()
52-
if prerequisites.parse_redis_url() is not None:
52+
if (
53+
"state_manager_mode" not in config._non_default_attributes
54+
and prerequisites.parse_redis_url() is not None
55+
):
5356
config.state_manager_mode = constants.StateManagerMode.REDIS
5457
if config.state_manager_mode == constants.StateManagerMode.MEMORY:
5558
from reflex.istate.manager.memory import StateManagerMemory

reflex/istate/manager/memory.py

Lines changed: 134 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,26 @@
33
import asyncio
44
import contextlib
55
import dataclasses
6+
import time
67
from collections.abc import AsyncIterator
78

89
from typing_extensions import Unpack, override
910

10-
from reflex.istate.manager import StateManager, StateModificationContext
11+
from reflex.istate.manager import (
12+
StateManager,
13+
StateModificationContext,
14+
_default_token_expiration,
15+
)
1116
from reflex.state import BaseState, _split_substate_key
1217

1318

1419
@dataclasses.dataclass
1520
class StateManagerMemory(StateManager):
1621
"""A state manager that stores states in memory."""
1722

23+
# The token expiration time (s).
24+
token_expiration: int = dataclasses.field(default_factory=_default_token_expiration)
25+
1826
# The mapping of client ids to states.
1927
states: dict[str, BaseState] = dataclasses.field(default_factory=dict)
2028

@@ -23,9 +31,104 @@ class StateManagerMemory(StateManager):
2331

2432
# The dict of mutexes for each client
2533
_states_locks: dict[str, asyncio.Lock] = dataclasses.field(
26-
default_factory=dict, init=False
34+
default_factory=dict,
35+
init=False,
36+
)
37+
38+
# The latest expiration deadline for each token.
39+
_token_expires_at: dict[str, float] = dataclasses.field(
40+
default_factory=dict,
41+
init=False,
2742
)
2843

44+
_expiration_task: asyncio.Task | None = dataclasses.field(default=None, init=False)
45+
46+
def _get_or_create_state(self, token: str) -> BaseState:
47+
"""Get an existing state or create a fresh one for a token.
48+
49+
Args:
50+
token: The normalized client token.
51+
52+
Returns:
53+
The state for the token.
54+
"""
55+
state = self.states.get(token)
56+
if state is None:
57+
state = self.states[token] = self.state(_reflex_internal_init=True)
58+
return state
59+
60+
def _track_token(self, token: str):
61+
"""Refresh the expiration deadline for an active token."""
62+
self._token_expires_at[token] = time.time() + self.token_expiration
63+
self._ensure_expiration_task()
64+
65+
def _purge_token(self, token: str):
66+
"""Remove a token from in-memory state bookkeeping."""
67+
self._token_expires_at.pop(token, None)
68+
self.states.pop(token, None)
69+
self._states_locks.pop(token, None)
70+
71+
def _purge_expired_tokens(self) -> float | None:
72+
"""Purge expired in-memory state entries and return the next deadline.
73+
74+
Returns:
75+
The next expiration deadline among unlocked tokens, if any.
76+
"""
77+
now = time.time()
78+
next_expires_at = None
79+
token_expires_at = self._token_expires_at
80+
state_locks = self._states_locks
81+
82+
for token, expires_at in list(token_expires_at.items()):
83+
if (
84+
state_lock := state_locks.get(token)
85+
) is not None and state_lock.locked():
86+
continue
87+
if expires_at <= now:
88+
self._purge_token(token)
89+
continue
90+
if next_expires_at is None or expires_at < next_expires_at:
91+
next_expires_at = expires_at
92+
93+
return next_expires_at
94+
95+
async def _get_state_lock(self, token: str) -> asyncio.Lock:
96+
"""Get or create the lock for a token.
97+
98+
Args:
99+
token: The normalized client token.
100+
101+
Returns:
102+
The lock protecting the token's state.
103+
"""
104+
state_lock = self._states_locks.get(token)
105+
if state_lock is None:
106+
async with self._state_manager_lock:
107+
state_lock = self._states_locks.get(token)
108+
if state_lock is None:
109+
state_lock = self._states_locks[token] = asyncio.Lock()
110+
return state_lock
111+
112+
async def _expire_states(self):
113+
"""Purge expired states until there are no unlocked deadlines left."""
114+
try:
115+
while True:
116+
if (next_expires_at := self._purge_expired_tokens()) is None:
117+
return
118+
await asyncio.sleep(max(0.0, next_expires_at - time.time()))
119+
finally:
120+
if self._expiration_task is asyncio.current_task():
121+
self._expiration_task = None
122+
123+
def _ensure_expiration_task(self):
124+
"""Ensure the expiration background task is running."""
125+
if self._expiration_task is None or self._expiration_task.done():
126+
asyncio.get_running_loop() # Ensure we're in an event loop.
127+
self._expiration_task = asyncio.create_task(
128+
self._expire_states(),
129+
name="StateManagerMemory|Expiration",
130+
)
131+
29132
@override
30133
async def get_state(self, token: str) -> BaseState:
31134
"""Get the state for a token.
@@ -38,9 +141,9 @@ async def get_state(self, token: str) -> BaseState:
38141
"""
39142
# Memory state manager ignores the substate suffix and always returns the top-level state.
40143
token = _split_substate_key(token)[0]
41-
if token not in self.states:
42-
self.states[token] = self.state(_reflex_internal_init=True)
43-
return self.states[token]
144+
state = self._get_or_create_state(token)
145+
self._track_token(token)
146+
return state
44147

45148
@override
46149
async def set_state(
@@ -58,6 +161,7 @@ async def set_state(
58161
"""
59162
token = _split_substate_key(token)[0]
60163
self.states[token] = state
164+
self._track_token(token)
61165

62166
@override
63167
@contextlib.asynccontextmanager
@@ -75,10 +179,28 @@ async def modify_state(
75179
"""
76180
# Memory state manager ignores the substate suffix and always returns the top-level state.
77181
token = _split_substate_key(token)[0]
78-
if token not in self._states_locks:
79-
async with self._state_manager_lock:
80-
if token not in self._states_locks:
81-
self._states_locks[token] = asyncio.Lock()
82-
83-
async with self._states_locks[token]:
84-
yield await self.get_state(token)
182+
state_lock = await self._get_state_lock(token)
183+
184+
try:
185+
async with state_lock:
186+
state = self._get_or_create_state(token)
187+
self._track_token(token)
188+
try:
189+
yield state
190+
finally:
191+
# Treat modify_state like a read followed by a write so the
192+
# expiration window starts after the state is no longer busy.
193+
self._track_token(token)
194+
finally:
195+
# Re-run expiration after the lock is released in case only locked
196+
# tokens were being tracked when the worker last ran.
197+
self._ensure_expiration_task()
198+
199+
async def close(self):
200+
"""Cancel the in-memory expiration task."""
201+
async with self._state_manager_lock:
202+
if self._expiration_task:
203+
self._expiration_task.cancel()
204+
with contextlib.suppress(asyncio.CancelledError):
205+
await self._expiration_task
206+
self._expiration_task = None
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""Integration tests for in-memory state expiration."""
2+
3+
import time
4+
from collections.abc import Generator
5+
6+
import pytest
7+
from selenium.webdriver.common.by import By
8+
from selenium.webdriver.remote.webdriver import WebDriver
9+
10+
from reflex.istate.manager.memory import StateManagerMemory
11+
from reflex.testing import AppHarness
12+
13+
14+
def MemoryExpirationApp():
15+
"""Reflex app that exposes state expiration through a simple counter UI."""
16+
import reflex as rx
17+
18+
class State(rx.State):
19+
counter: int = 0
20+
21+
@rx.event
22+
def increment(self):
23+
self.counter += 1
24+
25+
app = rx.App()
26+
27+
@app.add_page
28+
def index():
29+
return rx.vstack(
30+
rx.input(
31+
id="token",
32+
value=State.router.session.client_token,
33+
is_read_only=True,
34+
),
35+
rx.text(State.counter, id="counter"),
36+
rx.button("Increment", id="increment", on_click=State.increment),
37+
)
38+
39+
40+
@pytest.fixture
41+
def memory_expiration_app(
42+
app_harness_env: type[AppHarness],
43+
monkeypatch: pytest.MonkeyPatch,
44+
tmp_path_factory: pytest.TempPathFactory,
45+
) -> Generator[AppHarness, None, None]:
46+
"""Start a memory-backed app with a short expiration window.
47+
48+
Yields:
49+
A running app harness configured to use StateManagerMemory.
50+
"""
51+
monkeypatch.setenv("REFLEX_STATE_MANAGER_MODE", "memory")
52+
# Memory expiration reuses the shared token_expiration config field.
53+
monkeypatch.setenv("REFLEX_REDIS_TOKEN_EXPIRATION", "1")
54+
55+
with app_harness_env.create(
56+
root=tmp_path_factory.mktemp("memory_expiration_app"),
57+
app_name=f"memory_expiration_{app_harness_env.__name__.lower()}",
58+
app_source=MemoryExpirationApp,
59+
) as harness:
60+
assert isinstance(harness.state_manager, StateManagerMemory)
61+
yield harness
62+
63+
64+
@pytest.fixture
65+
def driver(memory_expiration_app: AppHarness) -> Generator[WebDriver, None, None]:
66+
"""Open the memory expiration app in a browser.
67+
68+
Yields:
69+
A webdriver instance pointed at the running app.
70+
"""
71+
assert memory_expiration_app.app_instance is not None, "app is not running"
72+
driver = memory_expiration_app.frontend()
73+
try:
74+
yield driver
75+
finally:
76+
driver.quit()
77+
78+
79+
def test_memory_state_manager_expires_state_end_to_end(
80+
memory_expiration_app: AppHarness,
81+
driver: WebDriver,
82+
):
83+
"""An idle in-memory state should expire and reset on the next event."""
84+
app_instance = memory_expiration_app.app_instance
85+
assert app_instance is not None
86+
87+
token_input = AppHarness.poll_for_or_raise_timeout(
88+
lambda: driver.find_element(By.ID, "token")
89+
)
90+
token = memory_expiration_app.poll_for_value(token_input)
91+
assert token is not None
92+
93+
counter = driver.find_element(By.ID, "counter")
94+
increment = driver.find_element(By.ID, "increment")
95+
app_state_manager = app_instance.state_manager
96+
assert isinstance(app_state_manager, StateManagerMemory)
97+
98+
AppHarness.expect(lambda: counter.text == "0")
99+
100+
increment.click()
101+
AppHarness.expect(lambda: counter.text == "1")
102+
103+
increment.click()
104+
AppHarness.expect(lambda: counter.text == "2")
105+
106+
AppHarness.expect(lambda: token in app_state_manager.states)
107+
AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5)
108+
109+
increment.click()
110+
AppHarness.expect(lambda: counter.text == "1")
111+
assert token_input.get_attribute("value") == token
112+
113+
114+
def test_memory_state_manager_delays_expiration_after_use_end_to_end(
115+
memory_expiration_app: AppHarness,
116+
driver: WebDriver,
117+
):
118+
"""Using a token should start a fresh expiration window from the last use."""
119+
app_instance = memory_expiration_app.app_instance
120+
assert app_instance is not None
121+
122+
token_input = AppHarness.poll_for_or_raise_timeout(
123+
lambda: driver.find_element(By.ID, "token")
124+
)
125+
token = memory_expiration_app.poll_for_value(token_input)
126+
assert token is not None
127+
128+
counter = driver.find_element(By.ID, "counter")
129+
increment = driver.find_element(By.ID, "increment")
130+
app_state_manager = app_instance.state_manager
131+
assert isinstance(app_state_manager, StateManagerMemory)
132+
133+
AppHarness.expect(lambda: counter.text == "0")
134+
135+
increment.click()
136+
AppHarness.expect(lambda: counter.text == "1")
137+
AppHarness.expect(lambda: token in app_state_manager.states)
138+
139+
time.sleep(0.6)
140+
increment.click()
141+
AppHarness.expect(lambda: counter.text == "2")
142+
AppHarness.expect(lambda: token in app_state_manager.states)
143+
144+
time.sleep(0.6)
145+
increment.click()
146+
AppHarness.expect(lambda: counter.text == "3")
147+
AppHarness.expect(lambda: token in app_state_manager.states)
148+
149+
time.sleep(0.6)
150+
assert token in app_state_manager.states
151+
assert counter.text == "3"
152+
153+
AppHarness.expect(lambda: token not in app_state_manager.states, timeout=5)
154+
155+
increment.click()
156+
AppHarness.expect(lambda: counter.text == "1")
157+
assert token_input.get_attribute("value") == token

0 commit comments

Comments
 (0)