|
4 | 4 | from typing import Any |
5 | 5 |
|
6 | 6 | from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, MCPServerStreamableHTTP |
| 7 | +from pydantic_ai.toolsets import AbstractToolset, PrefixedToolset |
7 | 8 | from sqlalchemy.ext.asyncio import AsyncSession |
8 | 9 |
|
9 | 10 | from backend.common.exception import errors |
|
13 | 14 | from backend.plugin.ai.model import Mcp |
14 | 15 | from backend.plugin.ai.schema.mcp import CreateMcpParam, UpdateMcpParam |
15 | 16 |
|
16 | | -McpToolset = MCPServerStdio | MCPServerSSE | MCPServerStreamableHTTP |
| 17 | +McpToolset = AbstractToolset[Any] |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class McpService: |
@@ -55,47 +56,43 @@ async def get_toolsets(*, db: AsyncSession, mcp_ids: list[int]) -> list[McpTools |
55 | 56 | mcps = await mcp_dao.get_by_ids(db, mcp_ids) |
56 | 57 | toolsets: list[McpToolset] = [] |
57 | 58 | for mcp in mcps: |
58 | | - headers = json.loads(mcp.headers) if isinstance(mcp.headers, str) else mcp.headers |
59 | | - if headers is not None and not isinstance(headers, dict): |
| 59 | + headers = json.loads(mcp.headers) if isinstance(mcp.headers, str) else (mcp.headers or {}) |
| 60 | + if not isinstance(headers, dict): |
60 | 61 | raise errors.RequestError(msg=f'MCP 请求头格式非法: {mcp.name}') |
61 | | - parsed_headers = None if headers is None else {str(key): str(value) for key, value in headers.items()} |
| 62 | + parsed_headers = {str(key): str(value) for key, value in headers.items()} |
62 | 63 | if mcp.type == McpType.stdio: |
63 | 64 | args = json.loads(mcp.args) if isinstance(mcp.args, str) else (mcp.args or []) |
64 | 65 | env = json.loads(mcp.env) if isinstance(mcp.env, str) else (mcp.env or {}) |
65 | 66 | if not isinstance(args, list): |
66 | 67 | raise errors.RequestError(msg=f'MCP 命令参数格式非法: {mcp.name}') |
67 | 68 | if not isinstance(env, dict): |
68 | 69 | raise errors.RequestError(msg=f'MCP 环境变量格式非法: {mcp.name}') |
69 | | - toolsets.append( |
70 | | - MCPServerStdio( |
71 | | - command=mcp.command, |
72 | | - args=[str(arg) for arg in args], |
73 | | - env={str(key): str(value) for key, value in env.items()}, |
74 | | - timeout=mcp.timeout, |
75 | | - ) |
| 70 | + toolset = MCPServerStdio( |
| 71 | + command=mcp.command, |
| 72 | + args=[str(arg) for arg in args], |
| 73 | + env={str(key): str(value) for key, value in env.items()}, |
| 74 | + timeout=mcp.timeout, |
76 | 75 | ) |
77 | 76 | elif mcp.type == McpType.sse: |
78 | 77 | if not mcp.url: |
79 | 78 | raise errors.RequestError(msg=f'MCP 缺少 SSE URL: {mcp.name}') |
80 | | - toolsets.append( |
81 | | - MCPServerSSE( |
82 | | - url=mcp.url, |
83 | | - headers=parsed_headers, |
84 | | - timeout=mcp.timeout, |
85 | | - read_timeout=mcp.read_timeout, |
86 | | - ) |
| 79 | + toolset = MCPServerSSE( |
| 80 | + url=mcp.url, |
| 81 | + headers=parsed_headers, |
| 82 | + timeout=mcp.timeout, |
| 83 | + read_timeout=mcp.read_timeout, |
87 | 84 | ) |
88 | 85 | else: |
89 | 86 | if not mcp.url: |
90 | 87 | raise errors.RequestError(msg=f'MCP 缺少 Streamable HTTP URL: {mcp.name}') |
91 | | - toolsets.append( |
92 | | - MCPServerStreamableHTTP( |
93 | | - url=mcp.url, |
94 | | - headers=parsed_headers, |
95 | | - timeout=mcp.timeout, |
96 | | - read_timeout=mcp.read_timeout, |
97 | | - ) |
| 88 | + toolset = MCPServerStreamableHTTP( |
| 89 | + url=mcp.url, |
| 90 | + headers=parsed_headers, |
| 91 | + timeout=mcp.timeout, |
| 92 | + read_timeout=mcp.read_timeout, |
98 | 93 | ) |
| 94 | + # 此举是为了为避免 MCP 工具名称冲突 |
| 95 | + toolsets.append(PrefixedToolset(toolset, prefix=f'mcp_{mcp.id}')) |
99 | 96 | return toolsets |
100 | 97 |
|
101 | 98 | @staticmethod |
|
0 commit comments