Skip to content

Commit e8ff541

Browse files
Python: consolidate MCP reliability fixes (#6145)
* Python: consolidate MCP reliability fixes
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix MCP cleanup and metadata typing
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Satisfy MCP metadata mypy typing
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Fix Pyright metadata mapping type
  Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent d2d5384 commit e8ff541

2 files changed

Lines changed: 278 additions & 20 deletions

File tree

python/packages/core/agent_framework/_mcp.py

Lines changed: 72 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import re
1111
import sys
1212
from abc import abstractmethod
13-
from collections.abc import Callable, Collection, Coroutine, Sequence
13+
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
1414
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
1515
from datetime import timedelta
1616
from functools import partial
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
142142
return meta
143143

144144

145+
def _url_origin(url: Any) -> tuple[str, str, int | None]:
146+
port = url.port
147+
if port is None:
148+
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
149+
return (url.scheme, url.host or "", port)
150+
151+
145152
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
146153
"""Lazily import the MCP streamable HTTP transport."""
147154
try:
@@ -255,6 +262,7 @@ def __init__(
255262
self._exit_stack = AsyncExitStack()
256263
self._lifecycle_lock = asyncio.Lock()
257264
self._lifecycle_request_lock = asyncio.Lock()
265+
self._function_load_lock = asyncio.Lock()
258266
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
259267
self._lifecycle_owner_task: asyncio.Task[None] | None = None
260268
self.session = session
@@ -655,6 +663,11 @@ async def _safe_close_exit_stack(self) -> None:
655663
raise
656664
except asyncio.CancelledError:
657665
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
666+
except Exception as e:
667+
if type(e).__name__ == "ExceptionGroup":
668+
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
669+
else:
670+
raise
658671

659672
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
660673
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
@@ -1018,6 +1031,10 @@ async def load_prompts(self) -> None:
10181031
Raises:
10191032
ToolExecutionException: If the MCP server is not connected.
10201033
"""
1034+
async with self._function_load_lock:
1035+
await self._load_prompts_locked()
1036+
1037+
async def _load_prompts_locked(self) -> None:
10211038
from anyio import ClosedResourceError
10221039
from mcp import types
10231040

@@ -1100,6 +1117,10 @@ async def load_tools(self) -> None:
11001117
Raises:
11011118
ToolExecutionException: If the MCP server is not connected.
11021119
"""
1120+
async with self._function_load_lock:
1121+
await self._load_tools_locked()
1122+
1123+
async def _load_tools_locked(self) -> None:
11031124
from anyio import ClosedResourceError
11041125
from mcp import types
11051126

@@ -1109,7 +1130,7 @@ async def load_tools(self) -> None:
11091130

11101131
# Track existing function names to prevent duplicates
11111132
existing_names = {func.name for func in self._functions}
1112-
self._tool_call_meta_by_name.clear()
1133+
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
11131134

11141135
params: types.PaginatedRequestParams | None = None
11151136
while True:
@@ -1145,7 +1166,7 @@ async def load_tools(self) -> None:
11451166

11461167
for tool in tool_list.tools:
11471168
if tool.meta is not None:
1148-
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
1169+
tool_call_meta_by_name[tool.name] = dict(tool.meta)
11491170

11501171
normalized_name = _normalize_mcp_name(tool.name)
11511172
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
@@ -1194,6 +1215,8 @@ async def _call_tool_with_runtime_kwargs(
11941215
break
11951216
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
11961217

1218+
self._tool_call_meta_by_name = tool_call_meta_by_name
1219+
11971220
async def _close_on_owner(self) -> None:
11981221
# Cancel any pending reload tasks before tearing down the session.
11991222
tasks = list(self._pending_reload_tasks)
@@ -1276,7 +1299,11 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
12761299
tool_name: The name of the tool to call.
12771300
12781301
Keyword Args:
1279-
kwargs: Arguments to pass to the tool.
1302+
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
1303+
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
1304+
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
1305+
non-conflicting keys.
1306+
kwargs: Remaining arguments to pass to the tool.
12801307
12811308
Returns:
12821309
A list of Content items representing the tool output. The default
@@ -1294,6 +1321,19 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
12941321
raise ToolExecutionException(
12951322
"Tools are not loaded for this server, please set load_tools=True in the constructor."
12961323
)
1324+
1325+
raw_user_meta: object | None = kwargs.get("_meta")
1326+
user_meta: dict[str, Any] | None = None
1327+
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
1328+
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
1329+
if isinstance(raw_user_meta, dict):
1330+
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
1331+
user_meta = {}
1332+
for key, value in raw_user_meta_dict.items():
1333+
if not isinstance(key, str):
1334+
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
1335+
user_meta[key] = value
1336+
12971337
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
12981338
# These are internal objects passed through the function invocation pipeline
12991339
# that should not be forwarded to external MCP servers.
@@ -1313,12 +1353,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
13131353
"conversation_id",
13141354
"options",
13151355
"response_format",
1356+
"_meta",
13161357
}
13171358
}
13181359

13191360
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
13201361
tool_meta = self._tool_call_meta_by_name.get(tool_name)
1321-
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
1362+
request_meta = dict(tool_meta) if tool_meta is not None else None
1363+
if user_meta is not None:
1364+
request_meta = {**(request_meta or {}), **user_meta}
1365+
meta = _inject_otel_into_mcp_meta(request_meta)
13221366

13231367
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
13241368
# Try the operation, reconnecting once if the connection is closed
@@ -1336,28 +1380,33 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]:
13361380
return parser(result)
13371381
except ToolExecutionException:
13381382
raise
1339-
except ClosedResourceError as cl_ex:
1383+
except (ClosedResourceError, McpError) as call_ex:
1384+
is_session_terminated = (
1385+
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
1386+
)
1387+
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
1388+
if not is_connection_lost:
1389+
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
1390+
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
1391+
13401392
if attempt == 0:
1341-
# First attempt failed, try reconnecting
1342-
logger.info("MCP connection closed unexpectedly. Reconnecting...")
1393+
# First attempt failed, try reconnecting.
1394+
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
13431395
try:
13441396
await self.connect(reset=True)
1345-
continue # Retry the operation
1397+
continue
13461398
except Exception as reconn_ex:
13471399
raise ToolExecutionException(
13481400
"Failed to reconnect to MCP server.",
13491401
inner_exception=reconn_ex,
13501402
) from reconn_ex
1351-
else:
1352-
# Second attempt also failed, give up
1353-
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
1354-
raise ToolExecutionException(
1355-
f"Failed to call tool '{tool_name}' - connection lost.",
1356-
inner_exception=cl_ex,
1357-
) from cl_ex
1358-
except McpError as mcp_exc:
1359-
error_message = mcp_exc.error.message
1360-
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
1403+
1404+
# Second attempt also failed, give up.
1405+
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
1406+
raise ToolExecutionException(
1407+
f"Failed to call tool '{tool_name}' - connection lost.",
1408+
inner_exception=call_ex,
1409+
) from call_ex
13611410
except Exception as ex:
13621411
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
13631412
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
@@ -1718,10 +1767,11 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
17181767
Returns:
17191768
An async context manager for the streamable HTTP client transport.
17201769
"""
1721-
from httpx import AsyncClient, Request, Timeout
1770+
from httpx import URL, AsyncClient, Request, Timeout
17221771

17231772
http_client = self._httpx_client
17241773
if self._header_provider is not None:
1774+
target_origin = _url_origin(URL(self.url))
17251775
if http_client is None:
17261776
http_client = AsyncClient(
17271777
follow_redirects=True,
@@ -1732,6 +1782,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
17321782
if not hasattr(self, "_inject_headers_hook"):
17331783

17341784
async def _inject_headers(request: Request) -> None: # noqa: RUF029
1785+
if _url_origin(request.url) != target_origin:
1786+
return
17351787
headers = _mcp_call_headers.get({})
17361788
for key, value in headers.items():
17371789
request.headers[key] = value

0 commit comments

Comments
 (0)