diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index fb24a4e5d1a..f3bccedf5fa 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -229,16 +229,22 @@ async def enumerate_tokens(self) -> AsyncIterator[str]: if not cursor: break - def _handle_socket_record_del(self, token: str) -> None: + async def _handle_socket_record_del( + self, token: str, expired: bool = False + ) -> None: """Handle deletion of a socket record from Redis. Args: token: The client token whose record was deleted. + expired: Whether the deletion was due to expiration. """ if ( socket_record := self.token_to_socket.pop(token, None) - ) is not None and socket_record.instance_id != self.instance_id: + ) is not None and socket_record.instance_id == self.instance_id: self.sid_to_token.pop(socket_record.sid, None) + if expired: + # Keep the record alive as long as this process is alive and not deleted. + await self.link_token_to_sid(token, socket_record.sid) async def _subscribe_socket_record_updates(self) -> None: """Subscribe to Redis keyspace notifications for socket record updates.""" @@ -262,7 +268,10 @@ async def _subscribe_socket_record_updates(self) -> None: event = message["data"].decode() if event in ("del", "expired", "evicted"): - self._handle_socket_record_del(token) + await self._handle_socket_record_del( + token, + expired=(event == "expired"), + ) elif event == "set": await self._get_token_owner(token, refresh=True)