Skip to content

Commit 9421a2b

Browse files
committed
fix: harden MCPDriver based on human review + MCP 1.26 SDK pass
Blocker fixes: - tests/test_mcp_driver.py: strengthen integration test assertion for real add(3,4) result; assert structured_content.result==7 and content[0].text=='7' rather than bare 'is not None' - mcp.py/_run_with_retry: handle McpError (protocol-level rejection) immediately as DriverError; McpError is not retryable — the server processed and rejected the request Major fixes: - mcp.py/discover(): paginate tools/list via _fetch_all_tools; loop on nextCursor until exhausted to avoid silent capability truncation on large MCP servers - mcp_support.py/ToolSpec + discover(): forward ToolAnnotations hints (readOnlyHint, destructiveHint, idempotentHint) to ToolSpec; derive SafetyClass from them via _infer_safety_class(); safety_class_map still overrides. Eliminates always-READ misclassification of destructive tools - mcp_support.py/call_tool(): forward read_timeout_seconds from constraints to ClientSession.call_tool(); prevents indefinite hangs over HTTP - mcp.py/_run_with_retry: document at-least-once delivery semantics for HTTP transport; advise callers to set max_retries=0 for WRITE/DESTRUCTIVE - pyproject.toml: tighten mcp lower bound to >=1.6 in both [mcp] extra and [dev] extras; ToolAnnotations, outputSchema, and nextCursor require this release line Minor fixes: - mcp_support.py/ToolSpec: forward Tool.outputSchema for downstream use by firewall budget/redaction rules - docs/integrations.md: expand HTTP section with full async def main() discover() + register + route example; add at-least-once warning inline
1 parent a03f631 commit 9421a2b

5 files changed

Lines changed: 126 additions & 28 deletions

File tree

docs/integrations.md

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,40 @@ asyncio.run(main())
4747
### Streamable HTTP transport
4848

4949
```python
50+
import asyncio
51+
52+
from agent_kernel import CapabilityRegistry, Kernel, StaticRouter
5053
from agent_kernel.drivers.mcp import MCPDriver
5154

52-
http_driver = MCPDriver.from_http(
53-
url="https://example.com/mcp",
54-
server_name="remote-tools",
55-
max_retries=1,
56-
)
55+
56+
async def main() -> None:
57+
registry = CapabilityRegistry()
58+
router = StaticRouter(fallback=[])
59+
kernel = Kernel(registry=registry, router=router)
60+
61+
# Connect to a remote Streamable HTTP MCP server.
62+
# Note: max_retries > 0 creates at-least-once delivery semantics for
63+
# tools/call — if a connection drops after the server processes the
64+
# request but before the response arrives, the call will be repeated.
65+
# Ensure target tools are idempotent, or set max_retries=0 for
66+
# WRITE/DESTRUCTIVE capabilities.
67+
driver = MCPDriver.from_http(
68+
url="https://example.com/mcp",
69+
server_name="remote-tools",
70+
max_retries=1,
71+
)
72+
kernel.register_driver(driver)
73+
74+
# Discover tools and register them as capabilities.
75+
capabilities = await driver.discover(namespace="remote")
76+
registry.register_many(capabilities)
77+
78+
# Route each discovered capability to this MCP driver.
79+
for capability in capabilities:
80+
router.add_route(capability.capability_id, [driver.driver_id])
81+
82+
83+
asyncio.run(main())
5784
```
5885

5986
### Notes

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ dev = [
3838
"ruff>=0.4",
3939
"mypy>=1.10",
4040
"httpx>=0.27",
41-
"mcp>=1.0",
41+
"mcp>=1.6",
4242
]
43-
mcp = ["mcp>=1.0"]
43+
mcp = ["mcp>=1.6"]
4444
otel = ["opentelemetry-api>=1.20"]
4545

4646
[tool.hatch.build.targets.wheel]

src/agent_kernel/drivers/mcp.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,35 @@
1111
from .base import ExecutionContext
1212
from .mcp_support import (
1313
SessionFactory,
14+
ToolSpec,
1415
build_http_session_factory,
1516
build_stdio_session_factory,
1617
call_tool,
1718
extract_tool_specs,
1819
normalize_call_result,
1920
)
2021

22+
# Lazy import of McpError — only available when the mcp optional dep is installed.
23+
# If mcp is absent, factory methods raise ImportError before any session is created,
24+
# so _McpError will never be None on a live driver instance.
25+
try:
26+
from mcp.shared.exceptions import McpError as _McpError
27+
except ImportError: # pragma: no cover
28+
_McpError = None # type: ignore[assignment,misc]
29+
30+
31+
def _infer_safety_class(spec: ToolSpec) -> SafetyClass:
32+
"""Infer a SafetyClass from MCP ToolAnnotations hints.
33+
34+
Uses a conservative default of READ when annotations are absent.
35+
The caller's safety_class_map takes precedence over the inferred value.
36+
"""
37+
if spec.destructive_hint:
38+
return SafetyClass.DESTRUCTIVE
39+
if spec.read_only_hint:
40+
return SafetyClass.READ
41+
return SafetyClass.READ
42+
2143

2244
class MCPDriver:
2345
"""A driver that invokes capabilities via MCP tools/call."""
@@ -92,19 +114,20 @@ async def discover(
92114
namespace: str | None = None,
93115
safety_class_map: dict[str, SafetyClass] | None = None,
94116
) -> list[Capability]:
95-
"""Discover MCP tools and convert them to capabilities."""
96-
tool_list = await self._run_with_retry(
117+
"""Discover MCP tools across all pages and convert them to capabilities."""
118+
tools = await self._run_with_retry(
97119
operation_name="tools/list",
98-
action=lambda session: session.list_tools(),
120+
action=self._fetch_all_tools,
99121
)
100122

101123
capabilities: list[Capability] = []
102-
for spec in extract_tool_specs(tool_list):
124+
for spec in extract_tool_specs(tools):
103125
capability_id = f"{namespace}.{spec.name}" if namespace else spec.name
126+
inferred = _infer_safety_class(spec)
104127
safety_class = (
105-
safety_class_map.get(spec.name, SafetyClass.READ)
128+
safety_class_map.get(spec.name, inferred)
106129
if safety_class_map is not None
107-
else SafetyClass.READ
130+
else inferred
108131
)
109132
capabilities.append(
110133
Capability(
@@ -121,21 +144,42 @@ async def discover(
121144
)
122145
return capabilities
123146

147+
async def _fetch_all_tools(self, session: Any) -> list[Any]:
148+
"""Paginate tools/list to exhaustion and return a flat list of Tool objects."""
149+
all_tools: list[Any] = []
150+
cursor: str | None = None
151+
while True:
152+
result = await session.list_tools(cursor=cursor)
153+
all_tools.extend(getattr(result, "tools", []) or [])
154+
cursor = getattr(result, "nextCursor", None)
155+
if not cursor:
156+
break
157+
return all_tools
158+
124159
async def execute(self, ctx: ExecutionContext) -> RawResult:
125160
"""Execute an MCP tool call for the given capability context."""
126161
operation = str(ctx.args.get("operation", ctx.capability_id))
127162
params = {k: v for k, v in ctx.args.items() if k != "operation"}
128163

129164
# Apply policy constraints as default arguments, without overriding explicit args.
165+
# read_timeout_seconds is an SDK control parameter — applied to the session call
166+
# directly rather than forwarded to the tool as an argument.
167+
read_timeout_seconds_raw = ctx.constraints.get("read_timeout_seconds")
130168
for key, value in ctx.constraints.items():
131-
params.setdefault(key, value)
169+
if key != "read_timeout_seconds":
170+
params.setdefault(key, value)
171+
172+
read_timeout_seconds: float | None = (
173+
float(read_timeout_seconds_raw) if read_timeout_seconds_raw is not None else None
174+
)
132175

133176
result = await self._run_with_retry(
134177
operation_name=f"tools/call:{operation}",
135178
action=lambda session: call_tool(
136179
session,
137180
operation=operation,
138181
params=params,
182+
read_timeout_seconds=read_timeout_seconds,
139183
),
140184
)
141185

@@ -170,11 +214,19 @@ async def _run_with_retry(
170214
except DriverError:
171215
raise
172216
except Exception as exc:
173-
# Broad catch is intentional: exceptions at this level are
174-
# session/transport failures (connection refused, EOF, timeout).
175-
# MCP tool-level application errors are returned as isError=True
176-
# responses and converted to DriverError before reaching this
177-
# handler — they never appear as Python exceptions here.
217+
# McpError is a protocol-level rejection (tool not found, auth
218+
# failure, invalid params) — the server processed and rejected the
219+
# request. It is not retryable; surface it immediately as DriverError.
220+
if _McpError is not None and isinstance(exc, _McpError):
221+
raise DriverError(
222+
f"MCPDriver '{self._driver_id}' received a protocol error "
223+
f"during {operation_name}: {exc}"
224+
) from exc
225+
# All other exceptions are session/transport failures (connection
226+
# refused, EOF, timeout) and are retryable for HTTP transport.
227+
# Note: HTTP retries create at-least-once delivery semantics for
228+
# tools/call. Callers using WRITE/DESTRUCTIVE capabilities over HTTP
229+
# should ensure the target tool is idempotent, or set max_retries=0.
178230
last_exc = exc
179231

180232
reason = str(last_exc) if last_exc is not None else "unknown transport failure"

src/agent_kernel/drivers/mcp_support.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterator, Callable
77
from contextlib import AbstractAsyncContextManager, asynccontextmanager
88
from dataclasses import dataclass
9+
from datetime import timedelta
910
from typing import Any
1011

1112
from ..errors import DriverError
@@ -19,27 +20,42 @@ class ToolSpec:
1920

2021
name: str
2122
description: str
22-
23-
24-
async def call_tool(session: Any, *, operation: str, params: dict[str, Any]) -> Any:
23+
read_only_hint: bool = False
24+
destructive_hint: bool = False
25+
idempotent_hint: bool = False
26+
output_schema: dict[str, Any] | None = None
27+
28+
29+
async def call_tool(
30+
session: Any,
31+
*,
32+
operation: str,
33+
params: dict[str, Any],
34+
read_timeout_seconds: float | None = None,
35+
) -> Any:
2536
"""Call an MCP tool via tools/call."""
26-
return await session.call_tool(operation, arguments=params)
37+
timeout = timedelta(seconds=read_timeout_seconds) if read_timeout_seconds is not None else None
38+
return await session.call_tool(operation, arguments=params, read_timeout_seconds=timeout)
2739

2840

29-
def extract_tool_specs(tool_list_response: Any) -> list[ToolSpec]:
30-
"""Extract tool metadata from a tools/list response payload."""
31-
tools = getattr(tool_list_response, "tools", [])
41+
def extract_tool_specs(tools: list[Any]) -> list[ToolSpec]:
42+
"""Extract tool metadata from a flat list of MCP Tool objects."""
3243
if not isinstance(tools, list):
3344
return []
3445
specs: list[ToolSpec] = []
3546
for tool in tools:
3647
name = getattr(tool, "name", None)
3748
if not isinstance(name, str) or not name:
3849
continue
50+
ann = getattr(tool, "annotations", None)
3951
specs.append(
4052
ToolSpec(
4153
name=name,
4254
description=str(getattr(tool, "description", "") or ""),
55+
read_only_hint=bool(getattr(ann, "readOnlyHint", False)),
56+
destructive_hint=bool(getattr(ann, "destructiveHint", False)),
57+
idempotent_hint=bool(getattr(ann, "idempotentHint", False)),
58+
output_schema=getattr(tool, "outputSchema", None),
4359
)
4460
)
4561
return specs

tests/test_mcp_driver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def __init__(
3737
self._call_error = call_error
3838
self.calls: list[tuple[str, dict[str, Any]]] = []
3939

40-
async def list_tools(self) -> ListToolsResult:
40+
async def list_tools(self, cursor: str | None = None) -> ListToolsResult:
4141
return ListToolsResult(tools=self._tools)
4242

4343
async def call_tool(
4444
self,
4545
operation: str,
4646
arguments: dict[str, Any],
47+
read_timeout_seconds: Any = None,
4748
) -> CallToolResult:
4849
self.calls.append((operation, arguments))
4950
if self._call_error is not None:
@@ -292,4 +293,6 @@ async def in_memory_factory() -> AsyncIterator[ClientSession]:
292293
args={"operation": "add", "a": 3, "b": 4},
293294
)
294295
result = await driver.execute(ctx)
295-
assert result.data is not None
296+
assert isinstance(result.data, dict)
297+
assert result.data["structured_content"]["result"] == 7
298+
assert result.data["content"][0]["text"] == "7"

0 commit comments

Comments
 (0)