Skip to content

Commit 3a0f532

Browse files
committed
Add cache_key and lock_key attributes to StateToken
1 parent f6dc002 commit 3a0f532

2 files changed

Lines changed: 52 additions & 19 deletions

File tree

reflex/istate/manager/memory.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class StateManagerMemory(StateManager):
3636
init=False,
3737
)
3838

39-
# The latest expiration deadline for each token.
40-
_token_expires_at: dict[str, float] = dataclasses.field(
39+
# The latest expiration deadline and token for each cache key.
40+
_token_expires_at: dict[str, tuple[float, StateToken]] = dataclasses.field(
4141
default_factory=dict,
4242
init=False,
4343
)
@@ -53,7 +53,7 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE:
5353
Returns:
5454
The state for the token.
5555
"""
56-
key = token.ident if isinstance(token, BaseStateToken) else str(token)
56+
key = token.cache_key
5757
if key not in self.states:
5858
if isinstance(token, BaseStateToken):
5959
self.states[key] = token.cls.get_root_state()(
@@ -65,15 +65,21 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE:
6565

6666
def _track_token(self, token: StateToken):
6767
"""Refresh the expiration deadline for an active token."""
68-
self._token_expires_at[token.ident] = time.time() + self.token_expiration
68+
self._token_expires_at[token.cache_key] = (
69+
time.time() + self.token_expiration,
70+
token,
71+
)
6972
self._ensure_expiration_task()
7073

7174
def _purge_token(self, token: StateToken):
72-
"""Remove a token from in-memory state bookkeeping."""
73-
key = token.ident if isinstance(token, BaseStateToken) else str(token)
74-
self._token_expires_at.pop(token.ident, None)
75-
self.states.pop(key, None)
76-
self._states_locks.pop(token.ident, None)
75+
"""Remove a token from in-memory state bookkeeping.
76+
77+
Args:
78+
token: The token to purge.
79+
"""
80+
self._token_expires_at.pop(token.cache_key, None)
81+
self._states_locks.pop(token.lock_key, None)
82+
self.states.pop(token.cache_key, None)
7783

7884
def _purge_expired_tokens(self) -> float | None:
7985
"""Purge expired in-memory state entries and return the next deadline.
@@ -86,15 +92,13 @@ def _purge_expired_tokens(self) -> float | None:
8692
token_expires_at = self._token_expires_at
8793
state_locks = self._states_locks
8894

89-
for token, expires_at in list(token_expires_at.items()):
95+
for _cache_key, (expires_at, token) in list(token_expires_at.items()):
9096
if (
91-
state_lock := state_locks.get(token)
97+
state_lock := state_locks.get(token.lock_key)
9298
) is not None and state_lock.locked():
9399
continue
94100
if expires_at <= now:
95-
self._purge_token(
96-
BaseStateToken(ident=token, cls=type(self.states[token]))
97-
)
101+
self._purge_token(token)
98102
continue
99103
if next_expires_at is None or expires_at < next_expires_at:
100104
next_expires_at = expires_at
@@ -110,12 +114,12 @@ async def _get_state_lock(self, token: StateToken) -> asyncio.Lock:
110114
Returns:
111115
The lock protecting the token's state.
112116
"""
113-
state_lock = self._states_locks.get(token.ident)
117+
state_lock = self._states_locks.get(token.lock_key)
114118
if state_lock is None:
115119
async with self._state_manager_lock:
116-
state_lock = self._states_locks.get(token.ident)
120+
state_lock = self._states_locks.get(token.lock_key)
117121
if state_lock is None:
118-
state_lock = self._states_locks[token.ident] = asyncio.Lock()
122+
state_lock = self._states_locks[token.lock_key] = asyncio.Lock()
119123
return state_lock
120124

121125
async def _expire_states(self):
@@ -166,8 +170,7 @@ async def set_state(
166170
state: The state to set.
167171
context: The state modification context.
168172
"""
169-
key = token.ident if isinstance(token, BaseStateToken) else str(token)
170-
self.states[key] = state
173+
self.states[token.cache_key] = state
171174
self._track_token(token)
172175

173176
@override

reflex/istate/manager/token.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,24 @@ def with_cls(self, cls: type[TOKEN_TYPE]) -> Self:
3737
"""
3838
return dataclasses.replace(self, cls=cls)
3939

40+
@property
41+
def cache_key(self) -> str:
42+
"""The key used for caching state instances in the StateManager.
43+
44+
Returns:
45+
A string key combining ident and class path.
46+
"""
47+
return str(self)
48+
49+
@property
50+
def lock_key(self) -> str:
51+
"""The key used for locking and session-level bookkeeping.
52+
53+
Returns:
54+
The token ident.
55+
"""
56+
return self.ident
57+
4058
def __str__(self) -> str:
4159
"""The key used in the underlying StateManager store.
4260
@@ -109,6 +127,18 @@ class BaseStateToken(StateToken["BaseState"]):
109127
This token type implies subtree hierarchy population and other semantic checks.
110128
"""
111129

130+
@property
131+
def cache_key(self) -> str:
132+
"""The key used for caching state instances in the StateManager.
133+
134+
BaseState tokens use just the ident because the entire state hierarchy
135+
lives under a single root state instance per session.
136+
137+
Returns:
138+
The token ident.
139+
"""
140+
return self.ident
141+
112142
def with_cls(self, cls: type[BaseState]) -> Self:
113143
"""Return a new token with the cls field updated to the provided class.
114144

0 commit comments

Comments
 (0)