-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathtest_mcp_client.py
More file actions
127 lines (101 loc) · 4.16 KB
/
test_mcp_client.py
File metadata and controls
127 lines (101 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations
import importlib.util
import logging
import sys
import types
from pathlib import Path
from typing import Generic, TypeVar
from unittest.mock import AsyncMock
import pytest
# This test file sits two directory levels below the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Source file under test; loaded directly from disk by load_mcp_client_module().
MCP_CLIENT_MODULE_PATH = REPO_ROOT.joinpath("astrbot/core/agent/mcp_client.py")
def _register_stub_module(name: str, **attrs: object) -> types.ModuleType:
    """Create an empty module called *name*, set *attrs* on it and register it.

    The module is stored in ``sys.modules`` unconditionally, replacing any
    existing entry, so this must only be called inside a ``sys.modules``
    snapshot that is restored afterwards (see ``load_mcp_client_module``).
    """
    module = types.ModuleType(name)
    for attr_name, value in attrs.items():
        setattr(module, attr_name, value)
    sys.modules[name] = module
    return module


def _install_mcp_client_stubs() -> None:
    """Register the minimal fake modules that ``mcp_client.py`` imports."""
    # Parent packages only need to exist (with a ``__path__``) so dotted and
    # relative imports inside mcp_client.py resolve; real, already-imported
    # packages are left as they are.
    for package_name in (
        "astrbot",
        "astrbot.core",
        "astrbot.core.agent",
        "astrbot.core.utils",
    ):
        if package_name not in sys.modules:
            _register_stub_module(package_name, __path__=[])

    astrbot_module = sys.modules["astrbot"]
    # Attribute changes on a pre-existing (real) package survive the
    # sys.modules restore, so never clobber a logger that is already there.
    if not hasattr(astrbot_module, "logger"):
        astrbot_module.logger = logging.getLogger("astrbot-test")

    _register_stub_module(
        "astrbot.core.utils.log_pipe", LogPipe=type("LogPipe", (), {})
    )

    TContext = TypeVar("TContext")

    class ContextWrapper(Generic[TContext]):
        pass

    _register_stub_module(
        "astrbot.core.agent.run_context",
        TContext=TContext,
        ContextWrapper=ContextWrapper,
    )
    _register_stub_module(
        "astrbot.core.agent.tool", FunctionTool=type("FunctionTool", (), {})
    )

    _register_stub_module(
        "anyio",
        ClosedResourceError=type("ClosedResourceError", (Exception,), {}),
    )
    _register_stub_module(
        "mcp",
        Tool=type("Tool", (), {}),
        ClientSession=type("ClientSession", (), {}),
        ListToolsResult=type("ListToolsResult", (), {}),
        StdioServerParameters=type("StdioServerParameters", (), {}),
        stdio_client=lambda *args, **kwargs: None,
        types=types.SimpleNamespace(
            LoggingMessageNotificationParams=type(
                "LoggingMessageNotificationParams", (), {}
            ),
            CallToolResult=type("CallToolResult", (), {}),
        ),
    )
    _register_stub_module("mcp.client")
    _register_stub_module(
        "mcp.client.sse", sse_client=lambda *args, **kwargs: None
    )
    _register_stub_module(
        "mcp.client.streamable_http",
        streamablehttp_client=lambda *args, **kwargs: None,
    )


def load_mcp_client_module():
    """Load ``astrbot/core/agent/mcp_client.py`` against stubbed dependencies.

    The module is executed under the name ``astrbot.core.agent.mcp_client``
    with fake ``anyio``, ``mcp`` and ``astrbot`` submodules in place. All
    ``sys.modules`` changes are rolled back before returning, including
    when loading fails. This stops the stubs from shadowing the real
    packages in other tests of the same session.

    Returns:
        The freshly executed ``astrbot.core.agent.mcp_client`` module.

    Raises:
        ImportError: if no import spec can be built for the source path.
    """
    # Local import: the file-level import only brings in AsyncMock.
    from unittest.mock import patch

    # patch.dict snapshots sys.modules and restores it on exit. The returned
    # module keeps the stub objects it bound at import time in its globals.
    # NOTE: any import that mcp_client performs lazily, after this returns,
    # will see the restored (real) sys.modules instead of the stubs.
    with patch.dict(sys.modules):
        _install_mcp_client_stubs()
        spec = importlib.util.spec_from_file_location(
            "astrbot.core.agent.mcp_client", MCP_CLIENT_MODULE_PATH
        )
        if spec is None or spec.loader is None:
            raise ImportError(
                f"cannot build an import spec for {MCP_CLIENT_MODULE_PATH}"
            )
        module = importlib.util.module_from_spec(spec)
        # Registered so that imports during exec_module can find the module.
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)
    return module
def test_sanitize_mcp_arguments_removes_nested_empty_collections():
    """Keys holding only empty values, even when nested, are dropped.

    ``filters`` holds only empty collections and ``metadata`` holds only an
    empty string and ``None``. The expectation is that sanitizing leaves
    just ``query``.
    """
    module = load_mcp_client_module()
    raw_arguments = {
        "query": "hello",
        "filters": {"tags": [], "scope": {}},
        "metadata": {"owner": "", "visibility": None},
    }

    result = module._sanitize_mcp_arguments(raw_arguments)

    assert result == {"query": "hello"}
@pytest.mark.asyncio
async def test_call_tool_with_reconnect_falls_back_to_empty_top_level_arguments():
    """Arguments that sanitize to nothing are forwarded as an empty dict.

    ``{"filters": {}, "query": ""}`` contains only empty values. The session's
    ``call_tool`` is therefore expected to be awaited once with
    ``arguments={}``, and its return value should be passed through unchanged.
    """
    module = load_mcp_client_module()
    one_second = module.timedelta(seconds=1)
    fake_call_tool = AsyncMock(return_value="ok")

    client = module.MCPClient()
    client.session = types.SimpleNamespace(call_tool=fake_call_tool)

    outcome = await client.call_tool_with_reconnect(
        tool_name="search",
        arguments={"filters": {}, "query": ""},
        read_timeout_seconds=one_second,
    )

    assert outcome == "ok"
    client.session.call_tool.assert_awaited_once_with(
        name="search",
        arguments={},
        read_timeout_seconds=one_second,
    )