forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathredis.py
More file actions
126 lines (104 loc) · 4.85 KB
/
redis.py
File metadata and controls
126 lines (104 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
import logging
from uuid import UUID
import mcp.types as types
try:
    import redis.asyncio as redis  # type: ignore[import]
except ImportError as e:
    # Re-raise with install instructions. Chaining with ``from e`` keeps the
    # original failure (e.g. a broken sub-dependency) visible in the traceback.
    raise ImportError(
        "Redis support requires the 'redis' package. "
        "Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'"
    ) from e

logger = logging.getLogger(__name__)
class RedisMessageQueue:
    """Redis implementation of the MessageQueue interface.

    This implementation uses Redis lists to store messages for each session.
    Redis provides persistence and allows multiple servers to share the same queue.

    Key layout (every key starts with the configured ``prefix``):

    * ``{prefix}active_sessions`` -- a Redis set holding the hex form of every
      registered session id.
    * ``{prefix}session:{session_id.hex}`` -- a Redis list used as a FIFO queue
      of serialized messages for one session (RPUSH to enqueue, BLPOP to dequeue).
    """

    def __init__(
        self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:queue:"
    ) -> None:
        """Initialize Redis message queue.

        Args:
            redis_url: Redis connection string
            prefix: Key prefix for Redis keys to avoid collisions
        """
        # decode_responses=True makes the client return str instead of bytes.
        self._redis = redis.Redis.from_url(redis_url, decode_responses=True)  # type: ignore[attr-defined]
        self._prefix = prefix
        self._active_sessions_key = f"{prefix}active_sessions"
        # Never log the raw URL: Redis URLs commonly embed a password
        # (``redis://:secret@host``) or a ``?password=`` query parameter.
        logger.debug(
            "Initialized Redis message queue with URL: %s",
            self._url_for_logging(redis_url),
        )

    @staticmethod
    def _url_for_logging(url: str) -> str:
        """Return *url* without user info or query string, so it is safe to log.

        Example: ``redis://:secret@host:6379/0?ssl=true`` -> ``redis://host:6379/0``.
        """
        from urllib.parse import urlsplit

        try:
            parts = urlsplit(url)
        except ValueError:
            return "<unparseable URL>"
        # Everything after the last "@" in the authority is host[:port].
        host = parts.netloc.rpartition("@")[2]
        return f"{parts.scheme}://{host}{parts.path}"

    def _session_queue_key(self, session_id: UUID) -> str:
        """Get the Redis key for a session's message queue."""
        return f"{self._prefix}session:{session_id.hex}"

    async def add_message(
        self, session_id: UUID, message: types.JSONRPCMessage | Exception
    ) -> bool:
        """Add a message to the queue for the specified session.

        Args:
            session_id: Session whose queue receives the message.
            message: A JSON-RPC message, or an exception to hand to the reader.

        Returns:
            True if the message was enqueued; False if the session is not
            registered (the message is dropped and a warning is logged).
        """
        if not await self.session_exists(session_id):
            logger.warning("Message received for unknown session %s", session_id)
            return False

        if isinstance(message, Exception):
            # Exceptions are not JSON-RPC messages. Store them as a tagged dict
            # that get_message() recognizes by its "_exception" marker.
            data = json.dumps(
                {
                    "_exception": True,
                    "type": type(message).__name__,
                    "message": str(message),
                }
            )
        else:
            data = message.model_dump_json(by_alias=True, exclude_none=True)

        # NOTE(review): the existence check above and this RPUSH are not atomic.
        # If another server unregisters the session in between, RPUSH recreates
        # the queue key after unregister_session() deleted it, leaving an
        # orphaned list behind.
        await self._redis.rpush(self._session_queue_key(session_id), data)  # type: ignore[attr-defined]
        logger.debug("Added message to Redis queue for session %s", session_id)
        return True

    async def get_message(
        self, session_id: UUID, timeout: float = 0.1
    ) -> types.JSONRPCMessage | Exception | None:
        """Get the next message for the specified session.

        Args:
            session_id: Session whose queue is read.
            timeout: Seconds BLPOP may block waiting for a message. Per Redis
                BLPOP semantics, 0 blocks indefinitely; fractional values
                require Redis >= 6.0.

        Returns:
            The next JSON-RPC message, or an ``Exception`` rebuilt from a stored
            exception record. Returns None if the session is not registered, the
            wait timed out, or the popped payload could not be deserialized. An
            undeserializable payload has already been removed from the queue and
            is logged.
        """
        if not await self.session_exists(session_id):
            return None

        # BLPOP pops from the left (the oldest entry) and waits up to `timeout`
        # seconds instead of busy-polling.
        result = await self._redis.blpop([self._session_queue_key(session_id)], timeout)  # type: ignore[attr-defined]
        if not result:
            return None

        # BLPOP returns a (key, value) pair.
        _, data = result  # type: ignore[misc]

        try:
            json_data = json.loads(data)  # type: ignore[arg-type]
        except json.JSONDecodeError as e:
            # Treat malformed JSON the same way as a schema failure below:
            # log and drop, instead of raising out of the reader loop.
            logger.error(
                "Failed to deserialize message for session %s: %s", session_id, e
            )
            return None

        if isinstance(json_data, dict) and json_data.get("_exception", False):
            exception_dict: dict[str, object] = json_data
            return Exception(
                f"{exception_dict.get('type', 'Exception')}: "
                f"{exception_dict.get('message', '')}"
            )

        try:
            return types.JSONRPCMessage.model_validate_json(data)  # type: ignore[arg-type]
        except Exception as e:
            logger.error(
                "Failed to deserialize message for session %s: %s", session_id, e
            )
            return None

    async def register_session(self, session_id: UUID) -> None:
        """Register a new session by adding its hex id to the active-sessions set."""
        await self._redis.sadd(self._active_sessions_key, session_id.hex)  # type: ignore[attr-defined]
        logger.debug("Registered session %s in Redis", session_id)

    async def unregister_session(self, session_id: UUID) -> None:
        """Unregister a session when it's closed.

        Removes the session from the active-sessions set, then deletes its
        message queue (discarding any undelivered messages).
        """
        await self._redis.srem(self._active_sessions_key, session_id.hex)  # type: ignore[attr-defined]
        await self._redis.delete(self._session_queue_key(session_id))  # type: ignore[attr-defined]
        logger.debug("Unregistered session %s from Redis", session_id)

    async def session_exists(self, session_id: UUID) -> bool:
        """Return True if the session's hex id is in the active-sessions set."""
        # Wrap in bool() so the type checker sees a bool regardless of what the
        # untyped client returns.
        is_member = await self._redis.sismember(self._active_sessions_key, session_id.hex)  # type: ignore[attr-defined]
        return bool(is_member)