diff --git a/reflex/app.py b/reflex/app.py index e384873dcc4..ad8a951245a 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -120,7 +120,7 @@ ) from reflex.utils.imports import ImportVar from reflex.utils.misc import run_in_thread -from reflex.utils.token_manager import TokenManager +from reflex.utils.token_manager import RedisTokenManager, TokenManager from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send if TYPE_CHECKING: @@ -2033,11 +2033,13 @@ def __init__(self, namespace: str, app: App): self._token_manager = TokenManager.create() @property - def token_to_sid(self) -> dict[str, str]: + def token_to_sid(self) -> Mapping[str, str]: """Get token to SID mapping for backward compatibility. + Note: this mapping is read-only. + Returns: - The token to SID mapping dict. + The token to SID mapping. """ # For backward compatibility, expose the underlying dict return self._token_manager.token_to_sid @@ -2059,6 +2061,9 @@ async def on_connect(self, sid: str, environ: dict): sid: The Socket.IO session id. environ: The request information, including HTTP headers. """ + if isinstance(self._token_manager, RedisTokenManager): + # Make sure this instance is watching for updates from other instances. + self._token_manager.ensure_lost_and_found_task(self.emit_update) query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", "")) token_list = query_params.get("token", []) if token_list: @@ -2072,11 +2077,14 @@ async def on_connect(self, sid: str, environ: dict): f"Frontend version {subprotocol} for session {sid} does not match the backend version {constants.Reflex.VERSION}." ) - def on_disconnect(self, sid: str): + def on_disconnect(self, sid: str) -> asyncio.Task | None: """Event for when the websocket disconnects. Args: sid: The Socket.IO session id. + + Returns: + An asyncio Task for cleaning up the token, or None. """ # Get token before cleaning up disconnect_token = self.sid_to_token.get(sid) @@ -2091,6 +2099,8 @@ def on_disconnect(self, sid: str): lambda t: t.exception() and console.error(f"Token cleanup error: {t.exception()}") ) + return task + return None async def emit_update(self, update: StateUpdate, token: str) -> None: """Emit an update to the client. @@ -2100,16 +2110,30 @@ async def emit_update(self, update: StateUpdate, token: str) -> None: token: The client token (tab) associated with the event. """ client_token, _ = _split_substate_key(token) - sid = self.token_to_sid.get(client_token) - if sid is None: - # If the sid is None, we are not connected to a client. Prevent sending - # updates to all clients. - console.warn(f"Attempting to send delta to disconnected client {token!r}") + socket_record = self._token_manager.token_to_socket.get(client_token) + if ( + socket_record is None + or socket_record.instance_id != self._token_manager.instance_id + ): + if isinstance(self._token_manager, RedisTokenManager): + # The socket belongs to another instance of the app, send it to the lost and found. + if not await self._token_manager.emit_lost_and_found( + client_token, update + ): + console.warn( + f"Failed to send delta to lost and found for client {token!r}" + ) + else: + # If the socket record is None, we are not connected to a client. Prevent sending + # updates to all clients. + console.warn( + f"Attempting to send delta to disconnected client {token!r}" + ) return # Creating a task prevents the update from being blocked behind other coroutines. await asyncio.create_task( - self.emit(str(constants.SocketEvent.EVENT), update, to=sid), - name=f"reflex_emit_event|{token}|{sid}|{time.time()}", + self.emit(str(constants.SocketEvent.EVENT), update, to=socket_record.sid), + name=f"reflex_emit_event|{token}|{socket_record.sid}|{time.time()}", ) async def on_event(self, sid: str, data: Any): diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 6bfa78a31ea..5c269b1fb41 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -67,6 +67,7 @@ class StateManagerRedis(StateManager): # The keyspace subscription string when redis is waiting for lock to be released. _redis_notify_keyspace_events: str = dataclasses.field( default="K" # Enable keyspace notifications (target a particular key) + "$" # For String commands (like setting keys) "g" # For generic commands (DEL, EXPIRE, etc) "x" # For expired events "e" # For evicted events (i.e. maxmemory exceeded) @@ -76,7 +77,6 @@ class StateManagerRedis(StateManager): _redis_keyspace_lock_release_events: set[bytes] = dataclasses.field( default_factory=lambda: { b"del", - b"expire", b"expired", b"evicted", } diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index b60b77e9743..9ec7aa2c53d 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -2,10 +2,17 @@ from __future__ import annotations +import asyncio +import dataclasses +import json import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from collections.abc import AsyncIterator, Callable, Coroutine +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, ClassVar +from reflex.istate.manager.redis import StateManagerRedis +from reflex.state import BaseState, StateUpdate from reflex.utils import console, prerequisites if TYPE_CHECKING: @@ -21,16 +28,54 @@ def _get_new_token() -> str: return str(uuid.uuid4()) +@dataclasses.dataclass(frozen=True, kw_only=True) +class SocketRecord: + """Record for a connected socket client.""" + + instance_id: str + sid: str + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class LostAndFoundRecord: + """Record for a StateUpdate for a token with its socket on another instance.""" + + token: str + update: dict[str, Any] + + class TokenManager(ABC): """Abstract base class for managing client token to session ID mappings.""" def __init__(self): """Initialize the token manager with local dictionaries.""" - # Keep a mapping between socket ID and client token. - self.token_to_sid: dict[str, str] = {} + # Each process has an instance_id to identify its own sockets. + self.instance_id: str = _get_new_token() # Keep a mapping between client token and socket ID. + self.token_to_socket: dict[str, SocketRecord] = {} + # Keep a mapping between socket ID and client token. self.sid_to_token: dict[str, str] = {} + @property + def token_to_sid(self) -> MappingProxyType[str, str]: + """Read-only compatibility property for token_to_socket mapping. + + Returns: + The token to session ID mapping. + """ + return MappingProxyType({ + token: sr.sid for token, sr in self.token_to_socket.items() + }) + + async def enumerate_tokens(self) -> AsyncIterator[str]: + """Iterate over all tokens in the system. + + Yields: + All client tokens known to the TokenManager. + """ + for token in self.token_to_socket: + yield token + @abstractmethod async def link_token_to_sid(self, token: str, sid: str) -> str | None: """Link a token to a session ID. @@ -68,7 +113,9 @@ def create(cls) -> TokenManager: async def disconnect_all(self): """Disconnect all tracked tokens when the server is going down.""" - token_sid_pairs: set[tuple[str, str]] = set(self.token_to_sid.items()) + token_sid_pairs: set[tuple[str, str]] = { + (token, sr.sid) for token, sr in self.token_to_socket.items() + } token_sid_pairs.update( ((token, sid) for sid, token in self.sid_to_token.items()) ) @@ -95,14 +142,20 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None: New token if duplicate detected and new token generated, None otherwise. """ # Check if token is already mapped to a different SID (duplicate tab) - if token in self.token_to_sid and sid != self.token_to_sid.get(token): + if ( + socket_record := self.token_to_socket.get(token) + ) is not None and sid != socket_record.sid: new_token = _get_new_token() - self.token_to_sid[new_token] = sid + self.token_to_socket[new_token] = SocketRecord( + instance_id=self.instance_id, sid=sid + ) self.sid_to_token[sid] = new_token return new_token # Normal case - link token to SID - self.token_to_sid[token] = sid + self.token_to_socket[token] = SocketRecord( + instance_id=self.instance_id, sid=sid + ) self.sid_to_token[sid] = token return None @@ -114,7 +167,7 @@ async def disconnect_token(self, token: str, sid: str) -> None: sid: The Socket.IO session ID. """ # Clean up both mappings - self.token_to_sid.pop(token, None) + self.token_to_socket.pop(token, None) self.sid_to_token.pop(sid, None) @@ -125,6 +178,8 @@ class RedisTokenManager(LocalTokenManager): for cross-worker duplicate detection. """ + _token_socket_record_prefix: ClassVar[str] = "token_manager_socket_record_" + def __init__(self, redis: Redis): """Initialize the Redis token manager. @@ -142,6 +197,10 @@ def __init__(self, redis: Redis): config = get_config() self.token_expiration = config.redis_token_expiration + # Pub/sub tasks for handling sockets owned by other instances. + self._socket_record_task: asyncio.Task | None = None + self._lost_and_found_task: asyncio.Task | None = None + def _get_redis_key(self, token: str) -> str: """Get Redis key for token mapping. @@ -149,9 +208,78 @@ def _get_redis_key(self, token: str) -> str: token: The client token. Returns: - Redis key following Reflex conventions: {token}_sid + Redis key following Reflex conventions: token_manager_socket_record_{token} + """ + return f"{self._token_socket_record_prefix}{token}" + + async def enumerate_tokens(self) -> AsyncIterator[str]: + """Iterate over all tokens in the system. + + Yields: + All client tokens known to the RedisTokenManager. """ - return f"{token}_sid" + cursor = 0 + while scan_result := await self.redis.scan( + cursor=cursor, match=self._get_redis_key("*") + ): + cursor = int(scan_result[0]) + for key in scan_result[1]: + yield key.decode().replace(self._token_socket_record_prefix, "") + if not cursor: + break + + def _handle_socket_record_del(self, token: str) -> None: + """Handle deletion of a socket record from Redis. + + Args: + token: The client token whose record was deleted. + """ + if ( + socket_record := self.token_to_socket.pop(token, None) + ) is not None and socket_record.instance_id != self.instance_id: + self.sid_to_token.pop(socket_record.sid, None) + + async def _subscribe_socket_record_updates(self, redis_db: int) -> None: + """Subscribe to Redis keyspace notifications for socket record updates.""" + async with self.redis.pubsub() as pubsub: + await pubsub.psubscribe( + f"__keyspace@{redis_db}__:{self._get_redis_key('*')}" + ) + async for message in pubsub.listen(): + if message["type"] == "pmessage": + key = message["channel"].split(b":", 1)[1].decode() + token = key.replace(self._token_socket_record_prefix, "") + + if token not in self.token_to_socket: + # We don't know about this token, skip + continue + + event = message["data"].decode() + if event in ("del", "expired", "evicted"): + self._handle_socket_record_del(token) + elif event == "set": + await self._get_token_owner(token, refresh=True) + + async def _socket_record_updates_forever(self) -> None: + """Background task to monitor Redis keyspace notifications for socket record updates.""" + await StateManagerRedis( + state=BaseState, redis=self.redis + )._enable_keyspace_notifications() + redis_db = self.redis.get_connection_kwargs().get("db", 0) + while True: + try: + await self._subscribe_socket_record_updates(redis_db) + except asyncio.CancelledError: # noqa: PERF203 + break + except Exception as e: + console.error(f"RedisTokenManager socket record update task error: {e}") + + def _ensure_socket_record_task(self) -> None: + """Ensure the socket record updates subscriber task is running.""" + if self._socket_record_task is None or self._socket_record_task.done(): + self._socket_record_task = asyncio.create_task( + self._socket_record_updates_forever() + ) async def link_token_to_sid(self, token: str, sid: str) -> str | None: """Link a token to a session ID with Redis-based duplicate detection. @@ -164,9 +292,14 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None: New token if duplicate detected and new token generated, None otherwise. """ # Fast local check first (handles reconnections) - if token in self.token_to_sid and self.token_to_sid[token] == sid: + if ( + socket_record := self.token_to_socket.get(token) + ) is not None and sid == socket_record.sid: return None # Same token, same SID = reconnection, no Redis check needed + # Make sure the update subscriber is running + self._ensure_socket_record_task() + # Check Redis for cross-worker duplicates redis_key = self._get_redis_key(token) @@ -176,34 +309,29 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None: console.error(f"Redis error checking token existence: {e}") return await super().link_token_to_sid(token, sid) + new_token = None if token_exists_in_redis: # Duplicate exists somewhere - generate new token - new_token = _get_new_token() - new_redis_key = self._get_redis_key(new_token) - - try: - # Store in Redis - await self.redis.set(new_redis_key, "1", ex=self.token_expiration) - except Exception as e: - console.error(f"Redis error storing new token: {e}") - # Still update local dicts and continue + token = new_token = _get_new_token() + redis_key = self._get_redis_key(new_token) - # Store in local dicts (always do this) - self.token_to_sid[new_token] = sid - self.sid_to_token[sid] = new_token - return new_token + # Store in local dicts + socket_record = self.token_to_socket[token] = SocketRecord( + instance_id=self.instance_id, sid=sid + ) + self.sid_to_token[sid] = token - # Normal case - store in both Redis and local dicts + # Store in Redis if possible try: - await self.redis.set(redis_key, "1", ex=self.token_expiration) + await self.redis.set( + redis_key, + json.dumps(dataclasses.asdict(socket_record)), + ex=self.token_expiration, + ) except Exception as e: console.error(f"Redis error storing token: {e}") - # Continue with local storage - - # Store in local dicts (always do this) - self.token_to_sid[token] = sid - self.sid_to_token[sid] = token - return None + # Return the new token if one was generated + return new_token async def disconnect_token(self, token: str, sid: str) -> None: """Clean up token mapping when client disconnects. @@ -213,7 +341,11 @@ async def disconnect_token(self, token: str, sid: str) -> None: sid: The Socket.IO session ID. """ # Only clean up if we own it locally (fast ownership check) - if self.token_to_sid.get(token) == sid: + if ( + (socket_record := self.token_to_socket.get(token)) is not None + and socket_record.sid == sid + and socket_record.instance_id == self.instance_id + ): # Clean up Redis redis_key = self._get_redis_key(token) try: @@ -223,3 +355,124 @@ async def disconnect_token(self, token: str, sid: str) -> None: # Clean up local dicts (always do this) await super().disconnect_token(token, sid) + + @staticmethod + def _get_lost_and_found_key(instance_id: str) -> str: + """Get the Redis key for lost and found deltas for an instance. + + Args: + instance_id: The instance ID. + + Returns: + The Redis key for lost and found deltas. + """ + return f"token_manager_lost_and_found_{instance_id}" + + async def _subscribe_lost_and_found_updates( + self, + emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]], + ) -> None: + """Subscribe to Redis channel notifications for lost and found deltas. + + Args: + emit_update: The function to emit state updates. + """ + async with self.redis.pubsub() as pubsub: + await pubsub.psubscribe( + f"channel:{self._get_lost_and_found_key(self.instance_id)}" + ) + async for message in pubsub.listen(): + if message["type"] == "pmessage": + record = LostAndFoundRecord(**json.loads(message["data"].decode())) + await emit_update(StateUpdate(**record.update), record.token) + + async def _lost_and_found_updates_forever( + self, + emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]], + ): + """Background task to monitor Redis lost and found deltas. + + Args: + emit_update: The function to emit state updates. + """ + while True: + try: + await self._subscribe_lost_and_found_updates(emit_update) + except asyncio.CancelledError: # noqa: PERF203 + break + except Exception as e: + console.error(f"RedisTokenManager lost and found task error: {e}") + + def ensure_lost_and_found_task( + self, + emit_update: Callable[[StateUpdate, str], Coroutine[None, None, None]], + ) -> None: + """Ensure the lost and found subscriber task is running. + + Args: + emit_update: The function to emit state updates. + """ + if self._lost_and_found_task is None or self._lost_and_found_task.done(): + self._lost_and_found_task = asyncio.create_task( + self._lost_and_found_updates_forever(emit_update) + ) + + async def _get_token_owner(self, token: str, refresh: bool = False) -> str | None: + """Get the instance ID of the owner of a token. + + Args: + token: The client token. + refresh: Whether to fetch the latest record from Redis. + + Returns: + The instance ID of the owner, or None if not found. + """ + if ( + not refresh + and (socket_record := self.token_to_socket.get(token)) is not None + ): + return socket_record.instance_id + + redis_key = self._get_redis_key(token) + try: + record_json = await self.redis.get(redis_key) + if record_json: + record_data = json.loads(record_json) + socket_record = SocketRecord(**record_data) + self.token_to_socket[token] = socket_record + self.sid_to_token[socket_record.sid] = token + return socket_record.instance_id + console.warn(f"Redis token owner not found for token {token}") + except Exception as e: + console.error(f"Redis error getting token owner: {e}") + return None + + async def emit_lost_and_found( + self, + token: str, + update: StateUpdate, + ) -> bool: + """Emit a lost and found delta to Redis. + + Args: + token: The client token. + update: The state update. + + Returns: + True if the delta was published, False otherwise. + """ + # See where this update belongs + owner_instance_id = await self._get_token_owner(token) + if owner_instance_id is None: + return False + record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update)) + try: + await self.redis.publish( + f"channel:{self._get_lost_and_found_key(owner_instance_id)}", + json.dumps(dataclasses.asdict(record)), + ) + except Exception as e: + console.error(f"Redis error publishing lost and found delta: {e}") + else: + return True + return False diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 9737d4ac2bb..885edae0ab1 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -170,7 +170,7 @@ async def test_connection_banner(connection_banner: AppHarness): await connection_banner.state_manager.redis.get( app_token_manager._get_redis_key(token) ) - == b"1" + == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode() ) delay_button = driver.find_element(By.ID, "delay") @@ -221,17 +221,17 @@ async def test_connection_banner(connection_banner: AppHarness): # After reconnecting, the token association should be re-established. app_token_manager = connection_banner.token_manager() + # Make sure the new connection has a different websocket sid. + sid_after = app_token_manager.token_to_sid[token] + assert sid_before != sid_after if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) assert ( await connection_banner.state_manager.redis.get( app_token_manager._get_redis_key(token) ) - == b"1" + == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode() ) - # Make sure the new connection has a different websocket sid. - sid_after = app_token_manager.token_to_sid[token] - assert sid_before != sid_after # Count should have incremented after coming back up assert connection_banner.poll_for_value(counter_element, exp_not_equal="1") == "2" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index b3bab332816..8a0995c9b03 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -53,6 +53,7 @@ UnretrievableVarValueError, ) from reflex.utils.format import json_dumps +from reflex.utils.token_manager import SocketRecord from reflex.vars.base import Var, computed_var from .states import GenState @@ -2016,7 +2017,10 @@ async def test_state_proxy( namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[router_data.session.session_id] = token - namespace.token_to_sid[token] = router_data.session.session_id + namespace._token_manager.instance_id = "mock" + namespace._token_manager.token_to_socket[token] = SocketRecord( + instance_id="mock", sid=router_data.session.session_id + ) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): mock_app.state_manager.states[parent_state.router.session.client_token] = ( parent_state @@ -2227,7 +2231,10 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): namespace = mock_app.event_namespace assert namespace is not None namespace.sid_to_token[sid] = token - namespace.token_to_sid[token] = sid + namespace._token_manager.instance_id = "mock" + namespace._token_manager.token_to_socket[token] = SocketRecord( + instance_id="mock", sid=sid + ) mock_app.state_manager.state = mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app, diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index a95544ec844..d9d891e3253 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -1,12 +1,21 @@ """Unit tests for TokenManager implementations.""" +import asyncio +import json +import time +from collections.abc import Callable, Generator +from contextlib import asynccontextmanager from unittest.mock import AsyncMock, Mock, patch import pytest +from reflex import config +from reflex.app import EventNamespace +from reflex.state import StateUpdate from reflex.utils.token_manager import ( LocalTokenManager, RedisTokenManager, + SocketRecord, TokenManager, ) @@ -61,6 +70,7 @@ def test_create_redis_when_redis_available( """ mock_check_redis_used.return_value = True mock_redis_client = Mock() + mock_redis_client.get_connection_kwargs.return_value = {"db": 0} mock_get_redis.return_value = mock_redis_client manager = TokenManager.create() @@ -174,6 +184,41 @@ async def test_disconnect_nonexistent_token(self, manager): assert len(manager.token_to_sid) == 0 assert len(manager.sid_to_token) == 0 + async def test_enumerate_tokens(self, manager): + """Test enumerate_tokens yields all linked tokens. + + Args: + manager: LocalTokenManager fixture instance. + """ + tokens_sids = [("token1", "sid1"), ("token2", "sid2"), ("token3", "sid3")] + + for token, sid in tokens_sids: + await manager.link_token_to_sid(token, sid) + + found_tokens = set() + async for token in manager.enumerate_tokens(): + found_tokens.add(token) + + expected_tokens = {token for token, _ in tokens_sids} + assert found_tokens == expected_tokens + + # Disconnect a token and ensure it's removed. + await manager.disconnect_token("token2", "sid2") + expected_tokens.remove("token2") + + found_tokens = set() + async for token in manager.enumerate_tokens(): + found_tokens.add(token) + + assert found_tokens == expected_tokens + + # Disconnect all tokens, none should remain + await manager.disconnect_all() + found_tokens = set() + async for token in manager.enumerate_tokens(): + found_tokens.add(token) + assert not found_tokens + class TestRedisTokenManager: """Tests for RedisTokenManager.""" @@ -189,6 +234,24 @@ def mock_redis(self): redis.exists = AsyncMock() redis.set = AsyncMock() redis.delete = AsyncMock() + + # Non-async call + redis.get_connection_kwargs = Mock(return_value={"db": 0}) + + # Mock out pubsub + async def listen(): + await asyncio.sleep(1) + if False: + yield + return + + @asynccontextmanager + async def pubsub(): # noqa: RUF029 + pubsub_mock = AsyncMock() + pubsub_mock.listen = listen + yield pubsub_mock + + redis.pubsub = pubsub return redis @pytest.fixture @@ -215,7 +278,7 @@ def test_get_redis_key(self, manager): manager: RedisTokenManager fixture instance. """ token = "test_token_123" - expected_key = f"{token}_sid" + expected_key = f"token_manager_socket_record_{token}" assert manager._get_redis_key(token) == expected_key @@ -232,9 +295,15 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis): result = await manager.link_token_to_sid(token, sid) assert result is None - mock_redis.exists.assert_called_once_with(f"{token}_sid") - mock_redis.set.assert_called_once_with(f"{token}_sid", "1", ex=3600) - assert manager.token_to_sid[token] == sid + mock_redis.exists.assert_called_once_with( + f"token_manager_socket_record_{token}" + ) + mock_redis.set.assert_called_once_with( + f"token_manager_socket_record_{token}", + json.dumps({"instance_id": manager.instance_id, "sid": sid}), + ex=3600, + ) + assert manager.token_to_socket[token].sid == sid assert manager.sid_to_token[sid] == token async def test_link_token_to_sid_reconnection_skips_redis( @@ -247,7 +316,9 @@ async def test_link_token_to_sid_reconnection_skips_redis( mock_redis: Mock Redis client fixture. """ token, sid = "token1", "sid1" - manager.token_to_sid[token] = sid + manager.token_to_socket[token] = SocketRecord( + instance_id=manager.instance_id, sid=sid + ) result = await manager.link_token_to_sid(token, sid) @@ -271,8 +342,14 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis): assert result != token assert len(result) == 36 # UUID4 length - mock_redis.exists.assert_called_once_with(f"{token}_sid") - mock_redis.set.assert_called_once_with(f"{result}_sid", "1", ex=3600) + mock_redis.exists.assert_called_once_with( + f"token_manager_socket_record_{token}" + ) + mock_redis.set.assert_called_once_with( + f"token_manager_socket_record_{result}", + json.dumps({"instance_id": manager.instance_id, "sid": sid}), + ex=3600, + ) assert manager.token_to_sid[result] == sid assert manager.sid_to_token[sid] == result @@ -323,12 +400,16 @@ async def test_disconnect_token_owned_locally(self, manager, mock_redis): mock_redis: Mock Redis client fixture. """ token, sid = "token1", "sid1" - manager.token_to_sid[token] = sid + manager.token_to_socket[token] = SocketRecord( + instance_id=manager.instance_id, sid=sid + ) manager.sid_to_token[sid] = token await manager.disconnect_token(token, sid) - mock_redis.delete.assert_called_once_with(f"{token}_sid") + mock_redis.delete.assert_called_once_with( + f"token_manager_socket_record_{token}" + ) assert token not in manager.token_to_sid assert sid not in manager.sid_to_token @@ -353,7 +434,9 @@ async def test_disconnect_token_redis_error(self, manager, mock_redis): mock_redis: Mock Redis client fixture. """ token, sid = "token1", "sid1" - manager.token_to_sid[token] = sid + manager.token_to_socket[token] = SocketRecord( + instance_id=manager.instance_id, sid=sid + ) manager.sid_to_token[sid] = token mock_redis.delete.side_effect = Exception("Redis delete error") @@ -402,3 +485,188 @@ def test_inheritance_from_local_manager(self, manager): assert isinstance(manager, LocalTokenManager) assert hasattr(manager, "token_to_sid") assert hasattr(manager, "sid_to_token") + + +@pytest.fixture +def redis_url(): + """Returns the Redis URL from the environment.""" + redis_url = config.get_config().redis_url + if redis_url is None: + pytest.skip("Redis URL not configured") + return redis_url + + +def query_string_for(token: str) -> dict[str, str]: + """Generate query string for given token. + + Args: + token: The token to generate query string for. + + Returns: + The generated query string. + """ + return {"QUERY_STRING": f"token={token}"} + + +@pytest.fixture +def event_namespace_factory() -> Generator[Callable[[], EventNamespace], None, None]: + """Yields the EventNamespace factory function.""" + namespace = config.get_config().get_event_namespace() + created_objs = [] + + def new_event_namespace() -> EventNamespace: + state = Mock() + state.router_data = {} + + mock_app = Mock() + mock_app.modify_state = Mock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=state)) + ) + + event_namespace = EventNamespace(namespace=namespace, app=mock_app) + event_namespace.emit = AsyncMock() + created_objs.append(event_namespace) + return event_namespace + + yield new_event_namespace + + for obj in created_objs: + asyncio.run(obj._token_manager.disconnect_all()) + + +@pytest.mark.usefixtures("redis_url") +@pytest.mark.asyncio +async def test_redis_token_manager_enumerate_tokens( + event_namespace_factory: Callable[[], EventNamespace], +): + """Integration test for RedisTokenManager enumerate_tokens interface. + + Should support enumerating tokens across separate instances of the + RedisTokenManager. + + Args: + event_namespace_factory: Factory fixture for EventNamespace instances. + """ + event_namespace1 = event_namespace_factory() + event_namespace2 = event_namespace_factory() + + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2")) + + found_tokens = set() + async for token in event_namespace1._token_manager.enumerate_tokens(): + found_tokens.add(token) + + assert "token1" in found_tokens + assert "token2" in found_tokens + assert len(found_tokens) == 2 + + await event_namespace1._token_manager.disconnect_all() + + found_tokens = set() + async for token in event_namespace1._token_manager.enumerate_tokens(): + found_tokens.add(token) + assert "token2" in found_tokens + assert len(found_tokens) == 1 + + await event_namespace2._token_manager.disconnect_all() + + found_tokens = set() + async for token in event_namespace1._token_manager.enumerate_tokens(): + found_tokens.add(token) + assert not found_tokens + + +@pytest.mark.usefixtures("redis_url") +@pytest.mark.asyncio +async def test_redis_token_manager_get_token_owner( + event_namespace_factory: Callable[[], EventNamespace], +): + """Integration test for RedisTokenManager get_token_owner interface. + + Should support retrieving the owner of a token across separate instances of the + RedisTokenManager. + + Args: + event_namespace_factory: Factory fixture for EventNamespace instances. + """ + event_namespace1 = event_namespace_factory() + event_namespace2 = event_namespace_factory() + + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2")) + + assert isinstance((manager1 := event_namespace1._token_manager), RedisTokenManager) + assert isinstance((manager2 := event_namespace2._token_manager), RedisTokenManager) + + assert await manager1._get_token_owner("token1") == manager1.instance_id + assert await manager1._get_token_owner("token2") == manager2.instance_id + assert await manager2._get_token_owner("token1") == manager1.instance_id + assert await manager2._get_token_owner("token2") == manager2.instance_id + + +async def _wait_for_call_count_positive(mock: Mock, timeout: float = 5.0): + """Wait until the mock's call count is positive. + + Args: + mock: The mock to wait on. + timeout: The maximum time to wait in seconds. + """ + deadline = time.monotonic() + timeout + while mock.call_count == 0 and time.monotonic() < deadline: # noqa: ASYNC110 + await asyncio.sleep(0.1) + + +@pytest.mark.usefixtures("redis_url") +@pytest.mark.asyncio +async def test_redis_token_manager_lost_and_found( + event_namespace_factory: Callable[[], EventNamespace], +): + """Updates emitted for lost and found tokens should be routed correctly via redis. + + Args: + event_namespace_factory: Factory fixture for EventNamespace instances. + """ + event_namespace1 = event_namespace_factory() + emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType] + event_namespace2 = event_namespace_factory() + emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType] + + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2")) + + await event_namespace2.emit_update(StateUpdate(), token="token1") + await _wait_for_call_count_positive(emit1_mock) + emit2_mock.assert_not_called() + emit1_mock.assert_called_once() + emit1_mock.reset_mock() + + await event_namespace2.emit_update(StateUpdate(), token="token2") + await _wait_for_call_count_positive(emit2_mock) + emit1_mock.assert_not_called() + emit2_mock.assert_called_once() + emit2_mock.reset_mock() + + if task := event_namespace1.on_disconnect(sid="sid1"): + await task + await event_namespace2.emit_update(StateUpdate(), token="token1") + # Update should be dropped on the floor. + await asyncio.sleep(2) + emit1_mock.assert_not_called() + emit2_mock.assert_not_called() + + await event_namespace2.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.emit_update(StateUpdate(), token="token1") + await _wait_for_call_count_positive(emit2_mock) + emit1_mock.assert_not_called() + emit2_mock.assert_called_once() + emit2_mock.reset_mock() + + if task := event_namespace2.on_disconnect(sid="sid1"): + await task + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.emit_update(StateUpdate(), token="token1") + await _wait_for_call_count_positive(emit1_mock) + emit2_mock.assert_not_called() + emit1_mock.assert_called_once() + emit1_mock.reset_mock()