Skip to content

Commit f3e393e

Browse files
fix: only open one connection/sub for each token per worker
bonus: properly cleanup StateManager connections on disconnect
1 parent 0bf8ffe commit f3e393e

2 files changed

Lines changed: 44 additions & 3 deletions

File tree

reflex/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,7 +1479,7 @@ def __init__(self, namespace: str, app: App):
14791479
self.sid_to_token = {}
14801480
self.app = app
14811481

1482-
def on_connect(self, sid, environ):
1482+
async def on_connect(self, sid, environ):
14831483
"""Event for when the websocket is connected.
14841484
14851485
Args:
@@ -1488,7 +1488,7 @@ def on_connect(self, sid, environ):
14881488
"""
14891489
pass
14901490

1491-
def on_disconnect(self, sid):
1491+
async def on_disconnect(self, sid):
14921492
"""Event for when the websocket disconnects.
14931493
14941494
Args:
@@ -1497,6 +1497,7 @@ def on_disconnect(self, sid):
14971497
disconnect_token = self.sid_to_token.pop(sid, None)
14981498
if disconnect_token:
14991499
self.token_to_sid.pop(disconnect_token, None)
1500+
await self.app.state_manager.disconnect(sid)
15001501

15011502
async def emit_update(self, update: StateUpdate, sid: str) -> None:
15021503
"""Emit an update to the client.

reflex/state.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2830,6 +2830,14 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28302830
"""
28312831
yield self.state()
28322832

2833+
async def disconnect(self, token: str) -> None:
2834+
"""Disconnect the client with the given token.
2835+
2836+
Args:
2837+
token: The token to disconnect.
2838+
"""
2839+
pass
2840+
28332841

28342842
class StateManagerMemory(StateManager):
28352843
"""A state manager that stores states in memory."""
@@ -2899,6 +2907,20 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
28992907
yield state
29002908
await self.set_state(token, state)
29012909

2910+
@override
2911+
async def disconnect(self, token: str) -> None:
2912+
"""Disconnect the client with the given token.
2913+
2914+
Args:
2915+
token: The token to disconnect.
2916+
"""
2917+
if token in self.states:
2918+
del self.states[token]
2919+
if lock := self._states_locks.get(token):
2920+
if lock.locked():
2921+
lock.release()
2922+
del self._states_locks[token]
2923+
29022924

29032925
def _default_token_expiration() -> int:
29042926
"""Get the default token expiration time.
@@ -3187,6 +3209,9 @@ class StateManagerRedis(StateManager):
31873209
b"evicted",
31883210
}
31893211

3212+
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
3213+
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
3214+
31903215
async def _get_parent_state(
31913216
self, token: str, state: BaseState | None = None
31923217
) -> BaseState | None:
@@ -3462,7 +3487,9 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34623487
# Some redis servers only allow out-of-band configuration, so ignore errors here.
34633488
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
34643489
raise
3465-
async with self.redis.pubsub() as pubsub:
3490+
if lock_key not in self._pubsub_locks:
3491+
self._pubsub_locks[lock_key] = asyncio.Lock()
3492+
async with self._pubsub_locks[lock_key], self.redis.pubsub() as pubsub:
34663493
await pubsub.psubscribe(lock_key_channel)
34673494
while not state_is_locked:
34683495
# wait for the lock to be released
@@ -3479,6 +3506,19 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
34793506
break
34803507
state_is_locked = await self._try_get_lock(lock_key, lock_id)
34813508

3509+
@override
3510+
async def disconnect(self, token: str):
3511+
"""Disconnect the token from the redis client.
3512+
3513+
Args:
3514+
token: The token to disconnect.
3515+
"""
3516+
lock_key = self._lock_key(token)
3517+
if lock := self._pubsub_locks.get(lock_key):
3518+
if lock.locked():
3519+
lock.release()
3520+
del self._pubsub_locks[lock_key]
3521+
34823522
@contextlib.asynccontextmanager
34833523
async def _lock(self, token: str):
34843524
"""Obtain a redis lock for a token.

0 commit comments

Comments
 (0)