-
Notifications
You must be signed in to change notification settings - Fork 1.7k
fix duplicate tab issue #4607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
fix duplicate tab issue #4607
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
b113281
fix duplicate tab issue
Lendemor 41d8cfa
Merge branch 'main' into lendemor/fix_duplicate_tab_issue
Lendemor 2825500
review changes
Lendemor 2f215a3
fix issue
Lendemor d2b183e
Merge branch 'main' into lendemor/fix_duplicate_tab_issue
Lendemor 988eda0
Merge branch 'main' into lendemor/fix_duplicate_tab_issue
Lendemor 0eb6674
handle token_to_sid mapping via redis when enabled
Lendemor d9710ff
fix for precommit
Lendemor cb8c442
fix closed loop
Lendemor 475bf83
use once on get_redis
Lendemor 4c32f95
revert once
Lendemor c0fccfb
fix
Lendemor 5d0db8c
handle indexerror
Lendemor 9348c37
fix
Lendemor 989a48c
fix race condition
Lendemor File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| import json | ||
| import sys | ||
| import traceback | ||
| import urllib.parse | ||
| from collections.abc import ( | ||
| AsyncGenerator, | ||
| AsyncIterator, | ||
|
|
@@ -114,6 +115,7 @@ | |
| ) | ||
| from reflex.utils.exec import get_compile_context, is_prod_mode, is_testing_env | ||
| from reflex.utils.imports import ImportVar | ||
| from reflex.utils.token_manager import TokenManager | ||
| from reflex.utils.types import ASGIApp, Message, Receive, Scope, Send | ||
|
|
||
| if TYPE_CHECKING: | ||
|
|
@@ -1952,12 +1954,6 @@ class EventNamespace(AsyncNamespace): | |
| # The application object. | ||
| app: App | ||
|
|
||
| # Keep a mapping between socket ID and client token. | ||
| token_to_sid: dict[str, str] | ||
|
|
||
| # Keep a mapping between client token and socket ID. | ||
| sid_to_token: dict[str, str] | ||
|
|
||
| def __init__(self, namespace: str, app: App): | ||
| """Initialize the event namespace. | ||
|
|
||
|
|
@@ -1966,17 +1962,45 @@ def __init__(self, namespace: str, app: App): | |
| app: The application object. | ||
| """ | ||
| super().__init__(namespace) | ||
| self.token_to_sid = {} | ||
| self.sid_to_token = {} | ||
| self.app = app | ||
|
|
||
| def on_connect(self, sid: str, environ: dict): | ||
| # Use TokenManager for distributed duplicate tab prevention | ||
| self._token_manager = TokenManager.create() | ||
|
|
||
| @property | ||
| def token_to_sid(self) -> dict[str, str]: | ||
| """Get token to SID mapping for backward compatibility. | ||
|
|
||
| Returns: | ||
| The token to SID mapping dict. | ||
| """ | ||
| # For backward compatibility, expose the underlying dict | ||
| return self._token_manager.token_to_sid | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should raise a deprecation here and below |
||
|
|
||
| @property | ||
| def sid_to_token(self) -> dict[str, str]: | ||
| """Get SID to token mapping for backward compatibility. | ||
|
|
||
| Returns: | ||
| The SID to token mapping dict. | ||
| """ | ||
| # For backward compatibility, expose the underlying dict | ||
| return self._token_manager.sid_to_token | ||
|
|
||
| async def on_connect(self, sid: str, environ: dict): | ||
| """Event for when the websocket is connected. | ||
|
|
||
| Args: | ||
| sid: The Socket.IO session id. | ||
| environ: The request information, including HTTP headers. | ||
| """ | ||
| query_params = urllib.parse.parse_qs(environ.get("QUERY_STRING", "")) | ||
| token_list = query_params.get("token", []) | ||
| if token_list: | ||
| await self.link_token_to_sid(sid, token_list[0]) | ||
| else: | ||
| console.warn(f"No token provided in connection for session {sid}") | ||
|
|
||
| subprotocol = environ.get("HTTP_SEC_WEBSOCKET_PROTOCOL") | ||
| if subprotocol and subprotocol != constants.Reflex.VERSION: | ||
| console.warn( | ||
|
|
@@ -1989,9 +2013,18 @@ def on_disconnect(self, sid: str): | |
| Args: | ||
| sid: The Socket.IO session id. | ||
| """ | ||
| disconnect_token = self.sid_to_token.pop(sid, None) | ||
| # Get token before cleaning up | ||
| disconnect_token = self.sid_to_token.get(sid) | ||
| if disconnect_token: | ||
| self.token_to_sid.pop(disconnect_token, None) | ||
| # Use async cleanup through token manager | ||
| task = asyncio.create_task( | ||
| self._token_manager.disconnect_token(disconnect_token, sid) | ||
| ) | ||
| # Don't await to avoid blocking disconnect, but handle potential errors | ||
| task.add_done_callback( | ||
| lambda t: t.exception() | ||
| and console.error(f"Token cleanup error: {t.exception()}") | ||
| ) | ||
|
|
||
| async def emit_update(self, update: StateUpdate, sid: str) -> None: | ||
| """Emit an update to the client. | ||
|
|
@@ -2049,8 +2082,13 @@ async def on_event(self, sid: str, data: Any): | |
| msg = f"Failed to deserialize event data: {fields}." | ||
| raise exceptions.EventDeserializationError(msg) from ex | ||
|
|
||
| self.token_to_sid[event.token] = sid | ||
| self.sid_to_token[sid] = event.token | ||
| # Correct the token if it doesn't match what we expect for this SID | ||
| expected_token = self.sid_to_token.get(sid) | ||
| if expected_token and event.token != expected_token: | ||
| # Create new event with corrected token since Event is frozen | ||
| from dataclasses import replace | ||
|
|
||
| event = replace(event, token=expected_token) | ||
|
|
||
| # Get the event environment. | ||
| if self.app.sio is None: | ||
|
|
@@ -2100,3 +2138,17 @@ async def on_ping(self, sid: str): | |
| """ | ||
| # Emit the test event. | ||
| await self.emit(str(constants.SocketEvent.PING), "pong", to=sid) | ||
|
|
||
| async def link_token_to_sid(self, sid: str, token: str): | ||
| """Link a token to a session id. | ||
|
|
||
| Args: | ||
| sid: The Socket.IO session id. | ||
| token: The client token. | ||
| """ | ||
| # Use TokenManager for duplicate detection and Redis support | ||
| new_token = await self._token_manager.link_token_to_sid(token, sid) | ||
|
|
||
| if new_token: | ||
| # Duplicate detected, emit new token to client | ||
| await self.emit("new_token", new_token, to=sid) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,215 @@ | ||
| """Token manager for handling client token to session ID mappings.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import uuid | ||
| from abc import ABC, abstractmethod | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from reflex.utils import console, prerequisites | ||
|
|
||
| if TYPE_CHECKING: | ||
| from redis.asyncio import Redis | ||
|
|
||
|
|
||
| def _get_new_token() -> str: | ||
| """Generate a new unique token. | ||
|
|
||
| Returns: | ||
| A new UUID4 token string. | ||
| """ | ||
| return str(uuid.uuid4()) | ||
|
|
||
|
|
||
| class TokenManager(ABC): | ||
| """Abstract base class for managing client token to session ID mappings.""" | ||
|
|
||
| def __init__(self): | ||
| """Initialize the token manager with local dictionaries.""" | ||
| # Keep a mapping between socket ID and client token. | ||
| self.token_to_sid: dict[str, str] = {} | ||
| # Keep a mapping between client token and socket ID. | ||
| self.sid_to_token: dict[str, str] = {} | ||
|
|
||
| @abstractmethod | ||
| async def link_token_to_sid(self, token: str, sid: str) -> str | None: | ||
| """Link a token to a session ID. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
|
|
||
| Returns: | ||
| New token if duplicate detected and new token generated, None otherwise. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| async def disconnect_token(self, token: str, sid: str) -> None: | ||
| """Clean up token mapping when client disconnects. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
| """ | ||
|
|
||
| @classmethod | ||
| def create(cls) -> TokenManager: | ||
| """Factory method to create appropriate TokenManager implementation. | ||
|
|
||
| Returns: | ||
| RedisTokenManager if Redis is available, LocalTokenManager otherwise. | ||
| """ | ||
| if prerequisites.check_redis_used(): | ||
| redis_client = prerequisites.get_redis() | ||
| if redis_client is not None: | ||
| return RedisTokenManager(redis_client) | ||
|
|
||
| return LocalTokenManager() | ||
|
|
||
|
|
||
| class LocalTokenManager(TokenManager): | ||
| """Token manager using local in-memory dictionaries (single worker).""" | ||
|
|
||
| def __init__(self): | ||
| """Initialize the local token manager.""" | ||
| super().__init__() | ||
|
|
||
| async def link_token_to_sid(self, token: str, sid: str) -> str | None: | ||
| """Link a token to a session ID. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
|
|
||
| Returns: | ||
| New token if duplicate detected and new token generated, None otherwise. | ||
| """ | ||
| # Check if token is already mapped to a different SID (duplicate tab) | ||
| if token in self.token_to_sid and sid != self.token_to_sid.get(token): | ||
| new_token = _get_new_token() | ||
| self.token_to_sid[new_token] = sid | ||
| self.sid_to_token[sid] = new_token | ||
| return new_token | ||
|
|
||
| # Normal case - link token to SID | ||
| self.token_to_sid[token] = sid | ||
| self.sid_to_token[sid] = token | ||
| return None | ||
|
|
||
| async def disconnect_token(self, token: str, sid: str) -> None: | ||
| """Clean up token mapping when client disconnects. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
| """ | ||
| # Clean up both mappings | ||
| self.token_to_sid.pop(token, None) | ||
| self.sid_to_token.pop(sid, None) | ||
|
|
||
|
|
||
| class RedisTokenManager(LocalTokenManager): | ||
| """Token manager using Redis for distributed multi-worker support. | ||
|
|
||
| Inherits local dict logic from LocalTokenManager and adds Redis layer | ||
| for cross-worker duplicate detection. | ||
| """ | ||
|
|
||
| def __init__(self, redis: Redis): | ||
| """Initialize the Redis token manager. | ||
|
|
||
| Args: | ||
| redis: The Redis client instance. | ||
| """ | ||
| # Initialize parent's local dicts | ||
| super().__init__() | ||
|
|
||
| self.redis = redis | ||
|
|
||
| # Get token expiration from config (default 1 hour) | ||
| from reflex.config import get_config | ||
|
|
||
| config = get_config() | ||
| self.token_expiration = config.redis_token_expiration | ||
|
|
||
| def _get_redis_key(self, token: str) -> str: | ||
| """Get Redis key for token mapping. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
|
|
||
| Returns: | ||
| Redis key following Reflex conventions: {token}_sid | ||
| """ | ||
| return f"{token}_sid" | ||
|
|
||
| async def link_token_to_sid(self, token: str, sid: str) -> str | None: | ||
| """Link a token to a session ID with Redis-based duplicate detection. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
|
|
||
| Returns: | ||
| New token if duplicate detected and new token generated, None otherwise. | ||
| """ | ||
| # Fast local check first (handles reconnections) | ||
| if token in self.token_to_sid and self.token_to_sid[token] == sid: | ||
| return None # Same token, same SID = reconnection, no Redis check needed | ||
|
|
||
| # Check Redis for cross-worker duplicates | ||
| redis_key = self._get_redis_key(token) | ||
|
|
||
| try: | ||
| token_exists_in_redis = await self.redis.exists(redis_key) | ||
| except Exception as e: | ||
| console.error(f"Redis error checking token existence: {e}") | ||
| return await super().link_token_to_sid(token, sid) | ||
|
|
||
| if token_exists_in_redis: | ||
| # Duplicate exists somewhere - generate new token | ||
| new_token = _get_new_token() | ||
| new_redis_key = self._get_redis_key(new_token) | ||
|
|
||
| try: | ||
| # Store in Redis | ||
| await self.redis.set(new_redis_key, "1", ex=self.token_expiration) | ||
| except Exception as e: | ||
| console.error(f"Redis error storing new token: {e}") | ||
| # Still update local dicts and continue | ||
|
|
||
| # Store in local dicts (always do this) | ||
| self.token_to_sid[new_token] = sid | ||
| self.sid_to_token[sid] = new_token | ||
| return new_token | ||
|
|
||
| # Normal case - store in both Redis and local dicts | ||
| try: | ||
| await self.redis.set(redis_key, "1", ex=self.token_expiration) | ||
| except Exception as e: | ||
| console.error(f"Redis error storing token: {e}") | ||
| # Continue with local storage | ||
|
|
||
| # Store in local dicts (always do this) | ||
| self.token_to_sid[token] = sid | ||
| self.sid_to_token[sid] = token | ||
| return None | ||
|
|
||
| async def disconnect_token(self, token: str, sid: str) -> None: | ||
| """Clean up token mapping when client disconnects. | ||
|
|
||
| Args: | ||
| token: The client token. | ||
| sid: The Socket.IO session ID. | ||
| """ | ||
| # Only clean up if we own it locally (fast ownership check) | ||
| if self.token_to_sid.get(token) == sid: | ||
| # Clean up Redis | ||
| redis_key = self._get_redis_key(token) | ||
| try: | ||
| await self.redis.delete(redis_key) | ||
| except Exception as e: | ||
| console.error(f"Redis error deleting token: {e}") | ||
|
|
||
| # Clean up local dicts (always do this) | ||
| await super().disconnect_token(token, sid) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider putting the
`TokenManager` on `App` as a real API into this data. Down the road we could even expose helper iterators on this new
`TokenManager` to make cancelling background tasks easier, like `session_is_connected` and `token_is_connected`, which would save the passed sid/token and `yield True` as long as these values were still being tracked. Like so
lot of possibilities for an asyncio.Event API for getting notified when a session or client disconnects, which people have asked for