-
Notifications
You must be signed in to change notification settings - Fork 860
Expand file tree
/
Copy pathmcp_client.py
More file actions
551 lines (454 loc) · 24.8 KB
/
mcp_client.py
File metadata and controls
551 lines (454 loc) · 24.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
"""Model Context Protocol (MCP) server connection management module.
This module provides the MCPClient class which handles connections to MCP servers.
It manages the lifecycle of MCP connections, including initialization, tool discovery,
tool invocation, and proper cleanup of resources. The connection runs in a background
thread to avoid blocking the main application thread while maintaining communication
with the MCP service.
"""
import asyncio
import base64
import logging
import threading
import uuid
from asyncio import AbstractEventLoop
from concurrent import futures
from datetime import timedelta
from types import TracebackType
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar, Union, cast
from mcp import ClientSession, ListToolsResult
from mcp.types import CallToolResult as MCPCallToolResult
from mcp.types import GetPromptResult, ListPromptsResult
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent
from mcp.types import EmbeddedResource as MCPEmbeddedResource
from ...types import PaginatedList
from ...types.exceptions import MCPClientInitializationError
from ...types.media import ImageFormat
from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
from .mcp_instrumentation import mcp_instrumentation
from .mcp_types import MCPToolResult, MCPTransport
# Module-level logger; MCPClient routes its debug output through _log_debug_with_thread.
logger = logging.getLogger(__name__)

# Result type of coroutines scheduled onto the background event loop.
T = TypeVar("T")

# Image MIME types understood by the framework, mapped to ImageFormat values.
# "image/jpg" is a non-standard alias some servers emit; it is treated as JPEG.
MIME_TO_FORMAT: Dict[str, ImageFormat] = {
    "image/jpeg": "jpeg",
    "image/jpg": "jpeg",
    "image/png": "png",
    "image/gif": "gif",
    "image/webp": "webp",
}

# Error text raised when a session-bound method is called outside a running client session.
CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = (
    "the client session is not running. "
    "Ensure the agent is used within the MCP client context manager. "
    "For more information see: "
    "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror"
)
class MCPClient:
    """Represents a connection to a Model Context Protocol (MCP) server.

    This class implements a context manager pattern for efficient connection management,
    allowing reuse of the same connection for multiple tool calls to reduce latency.
    It handles the creation, initialization, and cleanup of MCP connections.

    The connection runs in a background thread with its own asyncio event loop to avoid
    blocking the main application thread while maintaining communication with the MCP service.
    When an MCP tool returns structured content, it is attached to the returned MCPToolResult
    under the ``structuredContent`` key, next to the regular ``content`` list.
    """

    # Non-"text/*" MIME types whose embedded-resource blobs are decoded as UTF-8 text.
    _TEXTUAL_MIME_TYPES = frozenset(
        {
            "application/json",
            "application/xml",
            "application/javascript",
            "application/x-yaml",
            "application/yaml",
            "application/xhtml+xml",
        }
    )

    def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30):
        """Initialize a new MCP Server connection.

        No connection is made here; call start() or use the instance as a context manager.

        Args:
            transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
            startup_timeout: Seconds start() waits for the MCP server to initialize before it
                gives up, stops the background thread and raises. Defaults to 30.
        """
        self._startup_timeout = startup_timeout

        mcp_instrumentation()
        self._session_id = uuid.uuid4()
        self._log_debug_with_thread("initializing MCPClient connection")
        # start() blocks on this future until the background thread finishes initialization.
        self._init_future: futures.Future[None] = futures.Future()
        # asyncio.Event (not threading.Event) so waiting on it does not block the background loop.
        self._close_event = asyncio.Event()
        self._transport_callable = transport_callable

        self._background_thread: threading.Thread | None = None
        self._background_thread_session: ClientSession | None = None
        self._background_thread_event_loop: AbstractEventLoop | None = None

    def __enter__(self) -> "MCPClient":
        """Context manager entry point which initializes the MCP server connection.

        TODO: Refactor to lazy initialization pattern following idiomatic Python.
        Heavy work in __enter__ is non-idiomatic - should move connection logic to first method call instead.
        """
        return self.start()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit point that cleans up resources."""
        self.stop(exc_type, exc_val, exc_tb)

    def start(self) -> "MCPClient":
        """Starts the background thread and waits for initialization.

        This method starts the background thread that manages the MCP connection
        and blocks until the connection is ready or times out.

        Returns:
            self: The MCPClient instance

        Raises:
            MCPClientInitializationError: If a session is already running, or if the MCP
                connection fails to initialize within the timeout period.
        """
        if self._is_session_active():
            raise MCPClientInitializationError("the client session is currently running")

        if self._background_thread is not None:
            # A previous session ended on its own (e.g. the connection failed after initialization)
            # without stop() being called. Reset the leftover state so the stale, already-completed
            # _init_future from that run is not mistaken for this run's initialization.
            self.stop(None, None, None)

        self._log_debug_with_thread("entering MCPClient context")
        # Create the loop before the thread starts so stop() can always signal the close event,
        # even if it runs before the background thread gets going (e.g. a very small timeout).
        self._background_thread_event_loop = asyncio.new_event_loop()
        self._background_thread = threading.Thread(target=self._background_task, args=[], daemon=True)
        self._background_thread.start()

        self._log_debug_with_thread("background thread started, waiting for ready event")
        try:
            # Block the calling thread until the session is initialized in the background thread
            # or initialization fails.
            self._init_future.result(timeout=self._startup_timeout)
            self._log_debug_with_thread("the client initialization was successful")
        except futures.TimeoutError as e:
            logger.exception("client initialization timed out")
            # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
            self.stop(None, None, None)
            raise MCPClientInitializationError(
                f"background thread did not start in {self._startup_timeout} seconds"
            ) from e
        except Exception as e:
            logger.exception("client failed to initialize")
            # Pass None for exc_type, exc_val, exc_tb since this isn't a context manager exit
            self.stop(None, None, None)
            raise MCPClientInitializationError("the client initialization failed") from e
        return self

    def stop(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources.

        This method is defensive and can handle partial initialization states that may occur
        if start() fails partway through initialization, as well as a background thread that
        has already exited on its own.

        Resources to cleanup:
        - _background_thread: Thread running the async event loop
        - _background_thread_session: MCP ClientSession (auto-closed by context manager)
        - _background_thread_event_loop: AsyncIO event loop in background thread (closed by
          _background_task when the thread finishes)
        - _close_event: AsyncIO event to signal thread shutdown
        - _init_future: Future for initialization synchronization

        Cleanup order:
        1. Signal close event to background thread (if the event loop exists and is still open)
        2. Wait for background thread to complete
        3. Reset all state for reuse

        Args:
            exc_type: Exception type if an exception was raised in the context
            exc_val: Exception value if an exception was raised in the context
            exc_tb: Exception traceback if an exception was raised in the context
        """
        self._log_debug_with_thread("exiting MCPClient context")

        if self._background_thread is not None:
            loop = self._background_thread_event_loop
            if loop is not None:
                try:
                    # asyncio.Event is not thread-safe, so set it from the loop's own thread.
                    # A plain callback (rather than a coroutine) leaves nothing un-awaited if the
                    # loop never runs it.
                    loop.call_soon_threadsafe(self._close_event.set)
                except RuntimeError:
                    # The loop was already closed by _background_task (the background coroutine
                    # finished on its own), so there is nothing left to signal.
                    self._log_debug_with_thread("background event loop already closed")

            self._log_debug_with_thread("waiting for background thread to join")
            self._background_thread.join()
        self._log_debug_with_thread("background thread is closed, MCPClient context exited")

        # Reset fields to allow instance reuse
        self._init_future = futures.Future()
        self._close_event = asyncio.Event()
        self._background_thread = None
        self._background_thread_session = None
        self._background_thread_event_loop = None
        self._session_id = uuid.uuid4()

    def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]:
        """Synchronously retrieves the list of available tools from the MCP server.

        This method calls the asynchronous list_tools method on the MCP session
        and adapts the returned tools to the AgentTool interface.

        Args:
            pagination_token: Optional cursor returned by a previous call, for pagination

        Returns:
            PaginatedList[MCPAgentTool]: The adapted tools, with the server's next cursor as ``token``

        Raises:
            MCPClientInitializationError: If the client session is not running.
        """
        self._log_debug_with_thread("listing MCP tools synchronously")
        if not self._is_session_active():
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

        async def _list_tools_async() -> ListToolsResult:
            return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token)

        list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
        self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))

        mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools]
        self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools))
        return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor)

    def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult:
        """Synchronously retrieves the list of available prompts from the MCP server.

        This method calls the asynchronous list_prompts method on the MCP session
        and returns the raw ListPromptsResult with pagination support.

        Args:
            pagination_token: Optional token for pagination

        Returns:
            ListPromptsResult: The raw MCP response containing prompts and pagination info

        Raises:
            MCPClientInitializationError: If the client session is not running.
        """
        self._log_debug_with_thread("listing MCP prompts synchronously")
        if not self._is_session_active():
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

        async def _list_prompts_async() -> ListPromptsResult:
            return await cast(ClientSession, self._background_thread_session).list_prompts(cursor=pagination_token)

        list_prompts_result: ListPromptsResult = self._invoke_on_background_thread(_list_prompts_async()).result()
        self._log_debug_with_thread("received %d prompts from MCP server", len(list_prompts_result.prompts))
        for prompt in list_prompts_result.prompts:
            self._log_debug_with_thread(prompt.name)

        return list_prompts_result

    def get_prompt_sync(self, prompt_id: str, args: dict[str, Any]) -> GetPromptResult:
        """Synchronously retrieves a prompt from the MCP server.

        Args:
            prompt_id: The ID of the prompt to retrieve
            args: Arguments to pass to the prompt

        Returns:
            GetPromptResult: The prompt response from the MCP server

        Raises:
            MCPClientInitializationError: If the client session is not running.
        """
        self._log_debug_with_thread("getting MCP prompt synchronously")
        if not self._is_session_active():
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

        async def _get_prompt_async() -> GetPromptResult:
            return await cast(ClientSession, self._background_thread_session).get_prompt(prompt_id, arguments=args)

        get_prompt_result: GetPromptResult = self._invoke_on_background_thread(_get_prompt_async()).result()
        self._log_debug_with_thread("received prompt from MCP server")

        return get_prompt_result

    def call_tool_sync(
        self,
        tool_use_id: str,
        name: str,
        arguments: dict[str, Any] | None = None,
        read_timeout_seconds: timedelta | None = None,
    ) -> MCPToolResult:
        """Synchronously calls a tool on the MCP server.

        This method calls the asynchronous call_tool method on the MCP session
        and converts the result to the ToolResult format. If the MCP tool returns
        structured content, it is attached under the ``structuredContent`` key of
        the returned result.

        Args:
            tool_use_id: Unique identifier for this tool use
            name: Name of the tool to call
            arguments: Optional arguments to pass to the tool
            read_timeout_seconds: Optional timeout for the tool call

        Returns:
            MCPToolResult: The result of the tool call; failures during the call are
            reported as a result with status "error" rather than raised.

        Raises:
            MCPClientInitializationError: If the client session is not running.
        """
        self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
        if not self._is_session_active():
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

        try:
            call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(
                self._call_tool_on_session(name, arguments, read_timeout_seconds)
            ).result()
            return self._handle_tool_result(tool_use_id, call_tool_result)
        except Exception as e:
            logger.exception("tool execution failed")
            return self._handle_tool_execution_error(tool_use_id, e)

    async def call_tool_async(
        self,
        tool_use_id: str,
        name: str,
        arguments: dict[str, Any] | None = None,
        read_timeout_seconds: timedelta | None = None,
    ) -> MCPToolResult:
        """Asynchronously calls a tool on the MCP server.

        This method calls the asynchronous call_tool method on the MCP session
        and converts the result to the MCPToolResult format. The call itself runs on
        the background event loop; the caller's loop only awaits its completion.

        Args:
            tool_use_id: Unique identifier for this tool use
            name: Name of the tool to call
            arguments: Optional arguments to pass to the tool
            read_timeout_seconds: Optional timeout for the tool call

        Returns:
            MCPToolResult: The result of the tool call; failures during the call are
            reported as a result with status "error" rather than raised.

        Raises:
            MCPClientInitializationError: If the client session is not running.
        """
        self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
        if not self._is_session_active():
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

        try:
            future = self._invoke_on_background_thread(
                self._call_tool_on_session(name, arguments, read_timeout_seconds)
            )
            call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
            return self._handle_tool_result(tool_use_id, call_tool_result)
        except Exception as e:
            logger.exception("tool execution failed")
            return self._handle_tool_execution_error(tool_use_id, e)

    async def _call_tool_on_session(
        self,
        name: str,
        arguments: dict[str, Any] | None,
        read_timeout_seconds: timedelta | None,
    ) -> MCPCallToolResult:
        """Invoke ``call_tool`` on the live session; scheduled onto the background loop by call_tool_*."""
        return await cast(ClientSession, self._background_thread_session).call_tool(
            name, arguments, read_timeout_seconds
        )

    def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> MCPToolResult:
        """Build an error MCPToolResult whose single text item carries the exception message.

        Logging is left to the caller.
        """
        return MCPToolResult(
            status="error",
            toolUseId=tool_use_id,
            content=[{"text": f"Tool execution failed: {str(exception)}"}],
        )

    def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> MCPToolResult:
        """Maps MCP tool result to the agent's MCPToolResult format.

        Each content item is converted with _map_mcp_content_to_tool_result_content; items it
        cannot map are dropped. Non-empty structured content is attached under the
        ``structuredContent`` key.

        Args:
            tool_use_id: Unique identifier for this tool use
            call_tool_result: The result from the MCP tool call

        Returns:
            MCPToolResult: The converted tool result
        """
        self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))

        mapped_contents: list[ToolResultContent] = [
            mapped
            for content in call_tool_result.content
            if (mapped := self._map_mcp_content_to_tool_result_content(content)) is not None
        ]

        status: ToolResultStatus = "error" if call_tool_result.isError else "success"
        self._log_debug_with_thread("tool execution completed with status: %s", status)
        result = MCPToolResult(
            status=status,
            toolUseId=tool_use_id,
            content=mapped_contents,
        )

        if call_tool_result.structuredContent:
            result["structuredContent"] = call_tool_result.structuredContent

        return result

    async def _async_background_thread(self) -> None:
        """Asynchronous method that runs in the background thread to manage the MCP connection.

        This method establishes the transport connection, creates and initializes the MCP session,
        signals readiness to the main thread, and waits for a close signal. Exceptions raised
        before readiness is signalled are forwarded to start() through _init_future; later ones
        are only logged.
        """
        self._log_debug_with_thread("starting async background thread for MCP connection")
        try:
            async with self._transport_callable() as (read_stream, write_stream, *_):
                self._log_debug_with_thread("transport connection established")
                async with ClientSession(read_stream, write_stream) as session:
                    self._log_debug_with_thread("initializing MCP session")
                    await session.initialize()

                    self._log_debug_with_thread("session initialized successfully")
                    # Store the session for use while we await the close event
                    self._background_thread_session = session
                    # Signal that the session has been created and is ready for use
                    self._init_future.set_result(None)

                    self._log_debug_with_thread("waiting for close signal")
                    # Keep the connection open until signalled to close. This suspends only this
                    # coroutine; the loop keeps serving calls scheduled via _invoke_on_background_thread.
                    await self._close_event.wait()
                    self._log_debug_with_thread("close signal received")
        except Exception as e:
            # If the future is not done yet, the failure happened during initialization,
            # so report it to the thread blocked in start().
            if not self._init_future.done():
                self._init_future.set_exception(e)
            else:
                self._log_debug_with_thread(
                    "encountered exception on background thread after initialization %s", str(e)
                )
        finally:
            # The session is unusable once its context has exited; clear it so new calls fail fast
            # in _invoke_on_background_thread instead of being scheduled onto a stopping loop.
            self._background_thread_session = None

    def _background_task(self) -> None:
        """Runs the background event loop until the MCP connection coroutine finishes.

        Uses the event loop created by start() (creating one if none is set), installs it as this
        thread's current loop and runs _async_background_thread until completion, i.e. until the
        _close_event is set or the connection fails. The loop is always torn down and closed
        afterwards so repeated start()/stop() cycles do not leak event loops.
        """
        self._log_debug_with_thread("setting up background task event loop")
        loop = self._background_thread_event_loop
        if loop is None:
            loop = asyncio.new_event_loop()
            self._background_thread_event_loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_background_thread())
        finally:
            self._shutdown_event_loop(loop)

    def _shutdown_event_loop(self, loop: AbstractEventLoop) -> None:
        """Cancel leftover tasks, finalize async generators and the default executor, then close ``loop``.

        Tasks still pending here (for example calls scheduled via _invoke_on_background_thread
        that never completed) are cancelled, so threads blocked on their futures are released
        instead of waiting forever on a loop that no longer runs.
        """
        try:
            pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.run_until_complete(loop.shutdown_default_executor())
        except Exception:
            # Best-effort teardown: never let cleanup errors prevent the loop from being closed.
            logger.debug("error while shutting down the MCPClient background event loop", exc_info=True)
        finally:
            asyncio.set_event_loop(None)
            loop.close()
            self._log_debug_with_thread("background event loop closed")

    def _map_mcp_content_to_tool_result_content(
        self,
        content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any,
    ) -> Union[ToolResultContent, None]:
        """Maps MCP content types to tool result content types.

        This method converts MCP-specific content types to the generic
        ToolResultContent format used by the agent framework.

        Args:
            content: The MCP content to convert

        Returns:
            ToolResultContent or None: The converted content, or None if the content type (or an
            image/blob MIME type) is not supported or its data cannot be decoded
        """
        if isinstance(content, MCPTextContent):
            self._log_debug_with_thread("mapping MCP text content")
            return {"text": content.text}

        if isinstance(content, MCPImageContent):
            self._log_debug_with_thread("mapping MCP image content with mime type: %s", content.mimeType)
            image_format = self._image_format_for(content.mimeType)
            if image_format is None:
                # Drop unsupported images instead of failing the whole tool call with a KeyError.
                self._log_debug_with_thread(
                    "unsupported image mime type: %s - dropping content", content.mimeType
                )
                return None
            try:
                image_bytes = base64.b64decode(content.data)
            except ValueError:  # binascii.Error is a ValueError subclass
                self._log_debug_with_thread("image content could not be decoded - dropping")
                return None
            return {"image": {"format": image_format, "source": {"bytes": image_bytes}}}

        if isinstance(content, MCPEmbeddedResource):
            self._log_debug_with_thread("mapping MCP embedded resource content")
            return self._map_embedded_resource(content)

        self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__)
        return None

    def _map_embedded_resource(self, content: MCPEmbeddedResource) -> Union[ToolResultContent, None]:
        """Map an embedded resource to text, a decoded blob, or a URI reference (in that order)."""
        resource = getattr(content, "resource", None)
        if resource is None:
            self._log_debug_with_thread("embedded resource has no 'resource' field - dropping")
            return None

        text_val = getattr(resource, "text", None)
        if text_val:
            return {"text": text_val}

        blob_val = getattr(resource, "blob", None)
        mime_type = getattr(resource, "mimeType", None)
        if blob_val is not None:
            return self._map_embedded_blob(blob_val, mime_type)

        # Handle URI-only resources
        uri = getattr(resource, "uri", None)
        if uri:
            # NOTE: uri may not be a plain str (e.g. a URL model object); stringify it so the
            # json payload stays serializable.
            return {"json": {"uri": str(uri), "mime_type": mime_type}}

        self._log_debug_with_thread("embedded resource had no usable text/blob/uri; dropping")
        return None

    def _map_embedded_blob(self, blob_val: Any, mime_type: Optional[str]) -> Union[ToolResultContent, None]:
        """Decode an embedded-resource blob into text or image content based on ``mime_type``.

        ``blob_val`` is accepted as raw bytes or as a base64 string; any other type, or a string
        that fails to base64-decode, is dropped (returns None), as are non-textual, non-image blobs.
        """
        raw_bytes: Optional[bytes] = None
        if isinstance(blob_val, (bytes, bytearray)):
            raw_bytes = bytes(blob_val)
        elif isinstance(blob_val, str):
            try:
                raw_bytes = base64.b64decode(blob_val)
            except ValueError:  # bad padding (binascii.Error) or non-ASCII input
                raw_bytes = None

        if raw_bytes is None:
            self._log_debug_with_thread("embedded resource blob could not be decoded - dropping")
            return None

        if self._is_textual_mime_type(mime_type):
            # errors="replace" never raises, so undecodable bytes become U+FFFD instead of failing.
            return {"text": raw_bytes.decode("utf-8", errors="replace")}

        image_format = self._image_format_for(mime_type)
        if image_format is not None:
            return {"image": {"format": image_format, "source": {"bytes": raw_bytes}}}

        self._log_debug_with_thread("embedded resource blob with non-textual/unknown mimeType - dropping")
        return None

    @staticmethod
    def _base_mime_type(mime_type: Optional[str]) -> Optional[str]:
        """Return ``mime_type`` lower-cased with parameters removed ("Text/Plain; charset=x" -> "text/plain")."""
        if not mime_type:
            return None
        return mime_type.split(";", 1)[0].strip().lower()

    @classmethod
    def _is_textual_mime_type(cls, mime_type: Optional[str]) -> bool:
        """Return True if content with this MIME type should be decoded and returned as text."""
        base = cls._base_mime_type(mime_type)
        if not base:
            return False
        return base.startswith("text/") or base in cls._TEXTUAL_MIME_TYPES or base.endswith(("+json", "+xml"))

    @classmethod
    def _image_format_for(cls, mime_type: Optional[str]) -> Optional[ImageFormat]:
        """Return the ImageFormat for a supported image MIME type, or None if unsupported."""
        base = cls._base_mime_type(mime_type)
        return MIME_TO_FORMAT.get(base) if base else None

    def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Logger helper to help differentiate logs coming from MCPClient background thread.

        ``msg`` is %-formatted with ``args`` and prefixed with the current thread name and session id.
        Formatting is skipped entirely when DEBUG logging is disabled.
        """
        if not logger.isEnabledFor(logging.DEBUG):
            return
        formatted_msg = msg % args if args else msg
        logger.debug(
            "[Thread: %s, Session: %s] %s", threading.current_thread().name, self._session_id, formatted_msg, **kwargs
        )

    def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]:
        """Schedule ``coro`` on the background event loop and return a future for its result.

        Raises:
            MCPClientInitializationError: If no initialized session or event loop is available, or
                the event loop has already been closed. ``coro`` is closed in that case so it does
                not trigger a "coroutine was never awaited" warning.
        """
        loop = self._background_thread_event_loop
        if self._background_thread_session is None or loop is None:
            coro.close()
            raise MCPClientInitializationError("the client session was not initialized")
        try:
            return asyncio.run_coroutine_threadsafe(coro=coro, loop=loop)
        except RuntimeError as e:
            # Scheduling onto an already-closed event loop raises RuntimeError.
            coro.close()
            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) from e

    def _is_session_active(self) -> bool:
        """Return True while the background thread that owns the MCP session is alive."""
        return self._background_thread is not None and self._background_thread.is_alive()