|
39 | 39 | from mcp import ClientSession, StdioServerParameters |
40 | 40 | from mcp.client.sse import sse_client |
41 | 41 | from mcp.client.stdio import stdio_client |
| 42 | +from mcp.client.streamable_http import streamablehttp_client |
42 | 43 |
|
43 | | -from data_designer.config.mcp import LocalStdioMCPProvider, MCPProviderT |
| 44 | +from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT |
44 | 45 | from data_designer.engine.mcp.errors import MCPToolError |
45 | 46 | from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult |
46 | 47 |
|
@@ -211,11 +212,15 @@ async def create_session() -> ClientSession: |
211 | 212 | env=provider.env, |
212 | 213 | ) |
213 | 214 | ctx = stdio_client(params) |
| 215 | + elif isinstance(provider, MCPProvider) and provider.provider_type == "streamable_http": |
| 216 | + headers = _build_auth_headers(provider.api_key) |
| 217 | + ctx = streamablehttp_client(provider.endpoint, headers=headers) |
214 | 218 | else: |
215 | 219 | headers = _build_auth_headers(provider.api_key) |
216 | 220 | ctx = sse_client(provider.endpoint, headers=headers) |
217 | 221 |
|
218 | | - read, write = await ctx.__aenter__() |
| 222 | + ctx_result = await ctx.__aenter__() |
| 223 | + read, write = ctx_result[0], ctx_result[1] |
219 | 224 | new_session = ClientSession(read, write) |
220 | 225 | await new_session.__aenter__() |
221 | 226 | await new_session.initialize() |
@@ -399,6 +404,11 @@ def list_tools(provider: MCPProviderT, timeout_sec: float | None = None) -> tupl |
399 | 404 | return _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec) |
400 | 405 |
|
401 | 406 |
|
| 407 | +def list_tool_names(provider: MCPProviderT, timeout_sec: float) -> list[str]: |
| 408 | + """Return the names of all tools available on an MCP provider.""" |
| 409 | + return [t.name for t in _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec)] |
| 410 | + |
| 411 | + |
402 | 412 | def call_tools( |
403 | 413 | calls: list[tuple[MCPProviderT, str, dict[str, Any]]], |
404 | 414 | *, |
@@ -434,7 +444,7 @@ def get_session_pool_info() -> dict[str, Any]: |
434 | 444 |
|
435 | 445 |
|
436 | 446 | def _build_auth_headers(api_key: str | None) -> dict[str, Any] | None: |
437 | | - """Build authentication headers for SSE client.""" |
| 447 | + """Build authentication headers for remote MCP clients.""" |
438 | 448 | if not api_key: |
439 | 449 | return None |
440 | 450 | return {"Authorization": f"Bearer {api_key}"} |
|
0 commit comments