1010import re
1111import sys
1212from abc import abstractmethod
13- from collections .abc import Callable , Collection , Coroutine , Sequence
13+ from collections .abc import Callable , Collection , Coroutine , Mapping , Sequence
1414from contextlib import AsyncExitStack , _AsyncGeneratorContextManager # type: ignore
1515from datetime import timedelta
1616from functools import partial
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
142142 return meta
143143
144144
145+ def _url_origin (url : Any ) -> tuple [str , str , int | None ]:
146+ port = url .port
147+ if port is None :
148+ port = 443 if url .scheme == "https" else 80 if url .scheme == "http" else None
149+ return (url .scheme , url .host or "" , port )
150+
151+
145152def streamable_http_client (* args : Any , ** kwargs : Any ) -> _AsyncGeneratorContextManager [Any , None ]:
146153 """Lazily import the MCP streamable HTTP transport."""
147154 try :
@@ -255,6 +262,7 @@ def __init__(
255262 self ._exit_stack = AsyncExitStack ()
256263 self ._lifecycle_lock = asyncio .Lock ()
257264 self ._lifecycle_request_lock = asyncio .Lock ()
265+ self ._function_load_lock = asyncio .Lock ()
258266 self ._lifecycle_queue : asyncio .Queue [tuple [str , bool , bool , asyncio .Future [None ]]] | None = None
259267 self ._lifecycle_owner_task : asyncio .Task [None ] | None = None
260268 self .session = session
@@ -655,6 +663,11 @@ async def _safe_close_exit_stack(self) -> None:
655663 raise
656664 except asyncio .CancelledError :
657665 logger .warning ("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled." )
666+ except Exception as e :
667+ if type (e ).__name__ == "ExceptionGroup" :
668+ logger .warning ("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s" , e )
669+ else :
670+ raise
658671
659672 async def _close_and_check_cancelled (self , ex : BaseException ) -> bool :
660673 """Close the exit stack and return True if *ex* is a genuine task cancellation.
@@ -1018,6 +1031,10 @@ async def load_prompts(self) -> None:
10181031 Raises:
10191032 ToolExecutionException: If the MCP server is not connected.
10201033 """
1034+ async with self ._function_load_lock :
1035+ await self ._load_prompts_locked ()
1036+
1037+ async def _load_prompts_locked (self ) -> None :
10211038 from anyio import ClosedResourceError
10221039 from mcp import types
10231040
@@ -1100,6 +1117,10 @@ async def load_tools(self) -> None:
11001117 Raises:
11011118 ToolExecutionException: If the MCP server is not connected.
11021119 """
1120+ async with self ._function_load_lock :
1121+ await self ._load_tools_locked ()
1122+
1123+ async def _load_tools_locked (self ) -> None :
11031124 from anyio import ClosedResourceError
11041125 from mcp import types
11051126
@@ -1109,7 +1130,7 @@ async def load_tools(self) -> None:
11091130
11101131 # Track existing function names to prevent duplicates
11111132 existing_names = {func .name for func in self ._functions }
1112- self . _tool_call_meta_by_name . clear ()
1133+ tool_call_meta_by_name : dict [ str , dict [ str , Any ]] = {}
11131134
11141135 params : types .PaginatedRequestParams | None = None
11151136 while True :
@@ -1145,7 +1166,7 @@ async def load_tools(self) -> None:
11451166
11461167 for tool in tool_list .tools :
11471168 if tool .meta is not None :
1148- self . _tool_call_meta_by_name [tool .name ] = dict (tool .meta )
1169+ tool_call_meta_by_name [tool .name ] = dict (tool .meta )
11491170
11501171 normalized_name = _normalize_mcp_name (tool .name )
11511172 local_name = _build_prefixed_mcp_name (normalized_name , self .tool_name_prefix )
@@ -1194,6 +1215,8 @@ async def _call_tool_with_runtime_kwargs(
11941215 break
11951216 params = types .PaginatedRequestParams (cursor = tool_list .nextCursor )
11961217
1218+ self ._tool_call_meta_by_name = tool_call_meta_by_name
1219+
11971220 async def _close_on_owner (self ) -> None :
11981221 # Cancel any pending reload tasks before tearing down the session.
11991222 tasks = list (self ._pending_reload_tasks )
@@ -1276,7 +1299,11 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
12761299 tool_name: The name of the tool to call.
12771300
12781301 Keyword Args:
1279- kwargs: Arguments to pass to the tool.
1302+ _meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
1303+ ``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
1304+ User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
1305+ non-conflicting keys.
1306+ kwargs: Remaining arguments to pass to the tool.
12801307
12811308 Returns:
12821309 A list of Content items representing the tool output. The default
@@ -1294,6 +1321,19 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
12941321 raise ToolExecutionException (
12951322 "Tools are not loaded for this server, please set load_tools=True in the constructor."
12961323 )
1324+
1325+ raw_user_meta : object | None = kwargs .get ("_meta" )
1326+ user_meta : dict [str , Any ] | None = None
1327+ if raw_user_meta is not None and not isinstance (raw_user_meta , dict ):
1328+ raise ToolExecutionException ("MCP tool metadata provided via _meta must be a dict." )
1329+ if isinstance (raw_user_meta , dict ):
1330+ raw_user_meta_dict = cast (Mapping [object , object ], raw_user_meta )
1331+ user_meta = {}
1332+ for key , value in raw_user_meta_dict .items ():
1333+ if not isinstance (key , str ):
1334+ raise ToolExecutionException ("MCP tool metadata provided via _meta must use string keys." )
1335+ user_meta [key ] = value
1336+
12971337 # Filter out framework kwargs that cannot be serialized by the MCP SDK.
12981338 # These are internal objects passed through the function invocation pipeline
12991339 # that should not be forwarded to external MCP servers.
@@ -1313,12 +1353,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
13131353 "conversation_id" ,
13141354 "options" ,
13151355 "response_format" ,
1356+ "_meta" ,
13161357 }
13171358 }
13181359
13191360 # Some MCP proxies require their tools/list metadata to be echoed on tools/call.
13201361 tool_meta = self ._tool_call_meta_by_name .get (tool_name )
1321- meta = _inject_otel_into_mcp_meta (dict (tool_meta ) if tool_meta is not None else None )
1362+ request_meta = dict (tool_meta ) if tool_meta is not None else None
1363+ if user_meta is not None :
1364+ request_meta = {** (request_meta or {}), ** user_meta }
1365+ meta = _inject_otel_into_mcp_meta (request_meta )
13221366
13231367 parser = self .parse_tool_results or self ._parse_tool_result_from_mcp
13241368 # Try the operation, reconnecting once if the connection is closed
@@ -1336,28 +1380,33 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
13361380 return parser (result )
13371381 except ToolExecutionException :
13381382 raise
1339- except ClosedResourceError as cl_ex :
1383+ except (ClosedResourceError , McpError ) as call_ex :
1384+ is_session_terminated = (
1385+ isinstance (call_ex , McpError ) and "session terminated" in call_ex .error .message .lower ()
1386+ )
1387+ is_connection_lost = isinstance (call_ex , ClosedResourceError ) or is_session_terminated
1388+ if not is_connection_lost :
1389+ error_message = call_ex .error .message if isinstance (call_ex , McpError ) else str (call_ex )
1390+ raise ToolExecutionException (error_message , inner_exception = call_ex ) from call_ex
1391+
13401392 if attempt == 0 :
1341- # First attempt failed, try reconnecting
1342- logger .info ("MCP connection closed unexpectedly. Reconnecting..." )
1393+ # First attempt failed, try reconnecting.
1394+ logger .info ("MCP connection closed or terminated unexpectedly. Reconnecting..." )
13431395 try :
13441396 await self .connect (reset = True )
1345- continue # Retry the operation
1397+ continue
13461398 except Exception as reconn_ex :
13471399 raise ToolExecutionException (
13481400 "Failed to reconnect to MCP server." ,
13491401 inner_exception = reconn_ex ,
13501402 ) from reconn_ex
1351- else :
1352- # Second attempt also failed, give up
1353- logger .error (f"MCP connection closed unexpectedly after reconnection: { cl_ex } " )
1354- raise ToolExecutionException (
1355- f"Failed to call tool '{ tool_name } ' - connection lost." ,
1356- inner_exception = cl_ex ,
1357- ) from cl_ex
1358- except McpError as mcp_exc :
1359- error_message = mcp_exc .error .message
1360- raise ToolExecutionException (error_message , inner_exception = mcp_exc ) from mcp_exc
1403+
1404+ # Second attempt also failed, give up.
1405+ logger .error ("MCP connection closed unexpectedly after reconnection: %s" , call_ex )
1406+ raise ToolExecutionException (
1407+ f"Failed to call tool '{ tool_name } ' - connection lost." ,
1408+ inner_exception = call_ex ,
1409+ ) from call_ex
13611410 except Exception as ex :
13621411 raise ToolExecutionException (f"Failed to call tool '{ tool_name } '." , inner_exception = ex ) from ex
13631412 raise ToolExecutionException (f"Failed to call tool '{ tool_name } ' after retries." )
@@ -1718,10 +1767,11 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
17181767 Returns:
17191768 An async context manager for the streamable HTTP client transport.
17201769 """
1721- from httpx import AsyncClient , Request , Timeout
1770+ from httpx import URL , AsyncClient , Request , Timeout
17221771
17231772 http_client = self ._httpx_client
17241773 if self ._header_provider is not None :
1774+ target_origin = _url_origin (URL (self .url ))
17251775 if http_client is None :
17261776 http_client = AsyncClient (
17271777 follow_redirects = True ,
@@ -1732,6 +1782,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
17321782 if not hasattr (self , "_inject_headers_hook" ):
17331783
17341784 async def _inject_headers (request : Request ) -> None : # noqa: RUF029
1785+ if _url_origin (request .url ) != target_origin :
1786+ return
17351787 headers = _mcp_call_headers .get ({})
17361788 for key , value in headers .items ():
17371789 request .headers [key ] = value
0 commit comments