@@ -36,8 +36,8 @@ class StateManagerMemory(StateManager):
3636 init = False ,
3737 )
3838
39- # The latest expiration deadline for each token .
40- _token_expires_at : dict [str , float ] = dataclasses .field (
39+ # The latest expiration deadline and token for each cache key .
40+ _token_expires_at : dict [str , tuple [ float , StateToken ] ] = dataclasses .field (
4141 default_factory = dict ,
4242 init = False ,
4343 )
@@ -53,7 +53,7 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE:
5353 Returns:
5454 The state for the token.
5555 """
56- key = token .ident if isinstance ( token , BaseStateToken ) else str ( token )
56+ key = token .cache_key
5757 if key not in self .states :
5858 if isinstance (token , BaseStateToken ):
5959 self .states [key ] = token .cls .get_root_state ()(
@@ -65,15 +65,21 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE:
6565
6666 def _track_token (self , token : StateToken ):
6767 """Refresh the expiration deadline for an active token."""
68- self ._token_expires_at [token .ident ] = time .time () + self .token_expiration
68+ self ._token_expires_at [token .cache_key ] = (
69+ time .time () + self .token_expiration ,
70+ token ,
71+ )
6972 self ._ensure_expiration_task ()
7073
7174 def _purge_token (self , token : StateToken ):
72- """Remove a token from in-memory state bookkeeping."""
73- key = token .ident if isinstance (token , BaseStateToken ) else str (token )
74- self ._token_expires_at .pop (token .ident , None )
75- self .states .pop (key , None )
76- self ._states_locks .pop (token .ident , None )
75+ """Remove a token from in-memory state bookkeeping.
76+
77+ Args:
78+ token: The token to purge.
79+ """
80+ self ._token_expires_at .pop (token .cache_key , None )
81+ self ._states_locks .pop (token .lock_key , None )
82+ self .states .pop (token .cache_key , None )
7783
7884 def _purge_expired_tokens (self ) -> float | None :
7985 """Purge expired in-memory state entries and return the next deadline.
@@ -86,15 +92,13 @@ def _purge_expired_tokens(self) -> float | None:
8692 token_expires_at = self ._token_expires_at
8793 state_locks = self ._states_locks
8894
89- for token , expires_at in list (token_expires_at .items ()):
95+ for _cache_key , ( expires_at , token ) in list (token_expires_at .items ()):
9096 if (
91- state_lock := state_locks .get (token )
97+ state_lock := state_locks .get (token . lock_key )
9298 ) is not None and state_lock .locked ():
9399 continue
94100 if expires_at <= now :
95- self ._purge_token (
96- BaseStateToken (ident = token , cls = type (self .states [token ]))
97- )
101+ self ._purge_token (token )
98102 continue
99103 if next_expires_at is None or expires_at < next_expires_at :
100104 next_expires_at = expires_at
@@ -110,12 +114,12 @@ async def _get_state_lock(self, token: StateToken) -> asyncio.Lock:
110114 Returns:
111115 The lock protecting the token's state.
112116 """
113- state_lock = self ._states_locks .get (token .ident )
117+ state_lock = self ._states_locks .get (token .lock_key )
114118 if state_lock is None :
115119 async with self ._state_manager_lock :
116- state_lock = self ._states_locks .get (token .ident )
120+ state_lock = self ._states_locks .get (token .lock_key )
117121 if state_lock is None :
118- state_lock = self ._states_locks [token .ident ] = asyncio .Lock ()
122+ state_lock = self ._states_locks [token .lock_key ] = asyncio .Lock ()
119123 return state_lock
120124
121125 async def _expire_states (self ):
@@ -166,8 +170,7 @@ async def set_state(
166170 state: The state to set.
167171 context: The state modification context.
168172 """
169- key = token .ident if isinstance (token , BaseStateToken ) else str (token )
170- self .states [key ] = state
173+ self .states [token .cache_key ] = state
171174 self ._track_token (token )
172175
173176 @override
0 commit comments