-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathmemory.py
More file actions
122 lines (106 loc) · 4.86 KB
/
memory.py
File metadata and controls
122 lines (106 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
In-memory transports
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ElicitCompleteFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
ProgressNotificationFnT,
PromptListChangedFnT,
ResourceListChangedFnT,
ResourceUpdatedFnT,
SamplingFnT,
ToolListChangedFnT,
)
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.shared.message import SessionMessage
# One end of an in-memory connection, as a (read_stream, write_stream) pair. The read side is
# typed to carry either a SessionMessage or an Exception; the write side carries SessionMessage only.
MessageStream = tuple[
    MemoryObjectReceiveStream[SessionMessage | Exception],
    MemoryObjectSendStream[SessionMessage],
]
@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]:
    """
    Yield two cross-wired (read_stream, write_stream) pairs for an in-memory client/server link.

    What the client writes is read by the server, and what the server writes is read by the
    client. Each direction is an anyio memory object stream created with a buffer size of 1.
    All four stream ends are used as async context managers, so they are closed on exit.

    Returns:
        A tuple ``(client_streams, server_streams)`` in which each element is
        ``(read_stream, write_stream)``.
    """
    # One channel per direction; each factory call hands back (send_end, receive_end).
    to_client_tx, to_client_rx = anyio.create_memory_object_stream[SessionMessage | Exception](1)
    to_server_tx, to_server_rx = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    # Entered left to right, exited (closed) right to left.
    async with to_client_rx, to_server_tx, to_server_rx, to_client_tx:
        # Client reads what the server sends and writes what the server reads, and vice versa.
        yield (to_client_rx, to_server_tx), (to_server_rx, to_client_tx)
@asynccontextmanager
async def create_connected_server_and_client_session(
    server: Server[Any] | FastMCP,
    read_timeout_seconds: float | None = None,
    sampling_callback: SamplingFnT | None = None,
    list_roots_callback: ListRootsFnT | None = None,
    logging_callback: LoggingFnT | None = None,
    progress_notification_callback: ProgressNotificationFnT | None = None,
    resource_updated_callback: ResourceUpdatedFnT | None = None,
    resource_list_changed_callback: ResourceListChangedFnT | None = None,
    tool_list_changed_callback: ToolListChangedFnT | None = None,
    prompt_list_changed_callback: PromptListChangedFnT | None = None,
    message_handler: MessageHandlerFnT | None = None,
    client_info: types.Implementation | None = None,
    raise_exceptions: bool = False,
    elicitation_callback: ElicitationFnT | None = None,
    elicit_complete_callback: ElicitCompleteFnT | None = None,
) -> AsyncGenerator[ClientSession, None]:
    """Run ``server`` in a background task and yield an initialized ClientSession wired to it.

    Client and server communicate over the streams from ``create_client_server_memory_streams``.
    ``initialize()`` has already been awaited when the session is yielded. When the caller leaves
    the context, the background server task is cancelled.

    Args:
        server: A low-level ``Server`` or a ``FastMCP``; for a ``FastMCP`` its wrapped low-level
            server is the one that gets run.
        raise_exceptions: Passed through to ``Server.run``.
        read_timeout_seconds, message_handler, client_info and every ``*_callback`` argument:
            Forwarded unchanged to the ``ClientSession`` constructor.

    Yields:
        The connected, initialized ``ClientSession``.
    """
    # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport",
    # and we should expose a method in the `FastMCP` so we don't access a private attribute.
    lowlevel_server = server._mcp_server if isinstance(server, FastMCP) else server  # type: ignore[reportPrivateUsage]

    async with create_client_server_memory_streams() as (client_side, server_side):
        client_read, client_write = client_side
        server_read, server_write = server_side

        async with anyio.create_task_group() as tg:
            # Serve in the background; this task is cancelled in the `finally` below.
            tg.start_soon(
                lambda: lowlevel_server.run(
                    server_read,
                    server_write,
                    lowlevel_server.create_initialization_options(),
                    raise_exceptions=raise_exceptions,
                )
            )
            try:
                async with ClientSession(
                    read_stream=client_read,
                    write_stream=client_write,
                    read_timeout_seconds=read_timeout_seconds,
                    sampling_callback=sampling_callback,
                    list_roots_callback=list_roots_callback,
                    logging_callback=logging_callback,
                    progress_notification_callback=progress_notification_callback,
                    resource_updated_callback=resource_updated_callback,
                    resource_list_changed_callback=resource_list_changed_callback,
                    tool_list_changed_callback=tool_list_changed_callback,
                    prompt_list_changed_callback=prompt_list_changed_callback,
                    message_handler=message_handler,
                    client_info=client_info,
                    elicitation_callback=elicitation_callback,
                    elicit_complete_callback=elicit_complete_callback,
                ) as session:
                    await session.initialize()
                    yield session
            finally:  # pragma: no cover
                # Cancel the server whether the client exited cleanly or not, so that leaving the
                # task group does not wait for server.run to return on its own.
                tg.cancel_scope.cancel()