-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_session_group.py
More file actions
384 lines (338 loc) · 17.2 KB
/
test_session_group.py
File metadata and controls
384 lines (338 loc) · 17.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import contextlib
from unittest import mock
import httpx
import pytest
import mcp
from mcp import types
from mcp.client.session_group import (
ClientSessionGroup,
ClientSessionParameters,
SseServerParameters,
StreamableHttpParameters,
)
from mcp.client.stdio import StdioServerParameters
from mcp.shared.exceptions import McpError
@pytest.fixture
def mock_exit_stack() -> mock.MagicMock:
    """Provide a stand-in for the group's ``contextlib.AsyncExitStack``.

    The mock is specced against ``AsyncExitStack``, so only attributes that
    exist on the real class can be accessed. Any other attribute access raises
    ``AttributeError`` rather than silently returning a child mock.
    """
    spec_cls = contextlib.AsyncExitStack
    stack_double = mock.MagicMock(spec=spec_cls)
    return stack_double
@pytest.mark.anyio
class TestClientSessionGroup:
    """Unit tests for ``ClientSessionGroup``.

    Sessions, transports and server metadata are replaced with
    ``unittest.mock`` objects, so no real MCP server is started.
    """

    def test_init(self):
        """A newly constructed group has no components and no tool routing."""
        mcp_session_group = ClientSessionGroup()
        assert not mcp_session_group._tools
        assert not mcp_session_group._resources
        assert not mcp_session_group._prompts
        assert not mcp_session_group._tool_to_session

    def test_component_properties(self):
        """The public ``prompts``/``resources``/``tools`` properties reflect the internal dicts."""
        # --- Mock Dependencies ---
        mock_prompt = mock.Mock()
        mock_resource = mock.Mock()
        mock_tool = mock.Mock()

        # --- Prepare Session Group ---
        mcp_session_group = ClientSessionGroup()
        mcp_session_group._prompts = {"my_prompt": mock_prompt}
        mcp_session_group._resources = {"my_resource": mock_resource}
        mcp_session_group._tools = {"my_tool": mock_tool}

        # --- Assertions ---
        assert mcp_session_group.prompts == {"my_prompt": mock_prompt}
        assert mcp_session_group.resources == {"my_resource": mock_resource}
        assert mcp_session_group.tools == {"my_tool": mock_tool}

    async def test_call_tool(self):
        """``call_tool`` routes a group-level tool name to its session using the tool's own name."""
        # --- Mock Dependencies ---
        mock_session = mock.AsyncMock()

        # --- Prepare Session Group ---
        # The hook is never invoked here because components are injected directly
        # (hence ``pragma: no cover``). It shows how "server1-my_tool" would be formed.
        def hook(name: str, server_info: types.Implementation) -> str:  # pragma: no cover
            return f"{(server_info.name)}-{name}"

        mcp_session_group = ClientSessionGroup(component_name_hook=hook)
        mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})}
        mcp_session_group._tool_to_session = {"server1-my_tool": mock_session}
        text_content = types.TextContent(type="text", text="OK")
        mock_session.call_tool.return_value = types.CallToolResult(content=[text_content])

        # --- Test Execution ---
        result = await mcp_session_group.call_tool(
            name="server1-my_tool",
            arguments={
                "name": "value1",
                "args": {},
            },
        )

        # --- Assertions ---
        assert result.content == [text_content]
        # The session receives the tool's original name, not the hooked group-level name.
        mock_session.call_tool.assert_called_once_with(
            "my_tool",
            arguments={"name": "value1", "args": {}},
            read_timeout_seconds=None,
            progress_callback=None,
            meta=None,
        )

    async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):
        """Test connecting to a server and aggregating components."""
        # --- Mock Dependencies ---
        mock_server_info = mock.Mock(spec=types.Implementation)
        mock_server_info.name = "TestServer1"
        mock_session = mock.AsyncMock(spec=mcp.ClientSession)
        mock_tool1 = mock.Mock(spec=types.Tool)
        mock_tool1.name = "tool_a"
        mock_resource1 = mock.Mock(spec=types.Resource)
        mock_resource1.name = "resource_b"
        mock_prompt1 = mock.Mock(spec=types.Prompt)
        mock_prompt1.name = "prompt_c"
        mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1])
        mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1])
        mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1])

        # --- Test Execution ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack)
        with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
            await group.connect_to_server(StdioServerParameters(command="test"))

        # --- Assertions ---
        assert mock_session in group._sessions
        assert len(group.tools) == 1
        assert "tool_a" in group.tools
        assert group.tools["tool_a"] == mock_tool1
        assert group._tool_to_session["tool_a"] == mock_session
        assert len(group.resources) == 1
        assert "resource_b" in group.resources
        assert group.resources["resource_b"] == mock_resource1
        assert len(group.prompts) == 1
        assert "prompt_c" in group.prompts
        assert group.prompts["prompt_c"] == mock_prompt1
        mock_session.list_tools.assert_awaited_once()
        mock_session.list_resources.assert_awaited_once()
        mock_session.list_prompts.assert_awaited_once()

    async def test_connect_to_server_with_name_hook(self, mock_exit_stack: contextlib.AsyncExitStack):
        """Test connecting with a component name hook."""
        # --- Mock Dependencies ---
        mock_server_info = mock.Mock(spec=types.Implementation)
        mock_server_info.name = "HookServer"
        mock_session = mock.AsyncMock(spec=mcp.ClientSession)
        mock_tool = mock.Mock(spec=types.Tool)
        mock_tool.name = "base_tool"
        mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool])
        mock_session.list_resources.return_value = mock.AsyncMock(resources=[])
        mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[])

        # --- Test Setup ---
        def name_hook(name: str, server_info: types.Implementation) -> str:
            return f"{server_info.name}.{name}"

        # --- Test Execution ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook)
        with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)):
            await group.connect_to_server(StdioServerParameters(command="test"))

        # --- Assertions ---
        assert mock_session in group._sessions
        assert len(group.tools) == 1
        expected_tool_name = "HookServer.base_tool"
        assert expected_tool_name in group.tools
        assert group.tools[expected_tool_name] == mock_tool
        assert group._tool_to_session[expected_tool_name] == mock_session

    async def test_disconnect_from_server(self):
        """Disconnecting removes exactly the components recorded for that session."""
        # --- Test Setup ---
        group = ClientSessionGroup()
        server_name = "ServerToDisconnect"

        # Manually populate state using standard mocks
        mock_session1 = mock.MagicMock(spec=mcp.ClientSession)
        mock_session2 = mock.MagicMock(spec=mcp.ClientSession)
        mock_tool1 = mock.Mock(spec=types.Tool)
        mock_tool1.name = "tool1"
        mock_resource1 = mock.Mock(spec=types.Resource)
        mock_resource1.name = "res1"
        mock_prompt1 = mock.Mock(spec=types.Prompt)
        mock_prompt1.name = "prm1"
        mock_tool2 = mock.Mock(spec=types.Tool)
        mock_tool2.name = "tool2"
        # Keyed by the server's name but NOT listed in the disconnected session's
        # component names below, so disconnecting must leave it in place.
        mock_component_named_like_server = mock.Mock()
        mock_session = mock.Mock(spec=mcp.ClientSession)

        group._tools = {
            "tool1": mock_tool1,
            "tool2": mock_tool2,
            server_name: mock_component_named_like_server,
        }
        group._tool_to_session = {
            "tool1": mock_session1,
            "tool2": mock_session2,
            server_name: mock_session1,
        }
        group._resources = {
            "res1": mock_resource1,
            server_name: mock_component_named_like_server,
        }
        group._prompts = {
            "prm1": mock_prompt1,
            server_name: mock_component_named_like_server,
        }
        group._sessions = {
            mock_session: ClientSessionGroup._ComponentNames(
                prompts={"prm1"},
                resources={"res1"},
                tools={"tool1", "tool2"},
            )
        }

        # --- Assertions (preconditions) ---
        assert mock_session in group._sessions
        assert "tool1" in group._tools
        assert "tool2" in group._tools
        assert "res1" in group._resources
        assert "prm1" in group._prompts

        # --- Test Execution ---
        await group.disconnect_from_server(mock_session)

        # --- Assertions ---
        assert mock_session not in group._sessions
        assert "tool1" not in group._tools
        assert "tool2" not in group._tools
        assert "res1" not in group._resources
        assert "prm1" not in group._prompts
        # Entries that were not recorded for the disconnected session are untouched.
        assert group._tools[server_name] is mock_component_named_like_server
        assert group._resources[server_name] is mock_component_named_like_server
        assert group._prompts[server_name] is mock_component_named_like_server
        assert group._tool_to_session[server_name] is mock_session1

    async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack: contextlib.AsyncExitStack):
        """Test McpError raised when connecting a server with a dup name."""
        # --- Setup Pre-existing State ---
        group = ClientSessionGroup(exit_stack=mock_exit_stack)
        existing_tool_name = "shared_tool"
        # Manually add a tool to simulate a previous connection; keep a handle
        # so we can later check that this exact object is still registered.
        original_tool = mock.Mock(spec=types.Tool)
        original_tool.name = existing_tool_name
        group._tools[existing_tool_name] = original_tool
        # Need a dummy session associated with the existing tool
        mock_session = mock.MagicMock(spec=mcp.ClientSession)
        group._tool_to_session[existing_tool_name] = mock_session
        group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack)

        # --- Mock New Connection Attempt ---
        mock_server_info_new = mock.Mock(spec=types.Implementation)
        mock_server_info_new.name = "ServerWithDuplicate"
        mock_session_new = mock.AsyncMock(spec=mcp.ClientSession)

        # Configure the new session to return a tool with the *same name*
        duplicate_tool = mock.Mock(spec=types.Tool)
        duplicate_tool.name = existing_tool_name
        mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool])
        # Keep other lists empty for simplicity
        mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[])
        mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[])

        # --- Test Execution and Assertion ---
        with pytest.raises(McpError) as excinfo:
            with mock.patch.object(
                group,
                "_establish_session",
                return_value=(mock_server_info_new, mock_session_new),
            ):
                await group.connect_to_server(StdioServerParameters(command="test"))

        # Assert details about the raised error
        assert excinfo.value.error.code == types.INVALID_PARAMS
        assert existing_tool_name in excinfo.value.error.message
        assert "already exist " in excinfo.value.error.message

        # Verify the duplicate tool was *not* added (state should be unchanged):
        # the original object is still registered and still routed to its session.
        assert len(group._tools) == 1
        assert group._tools[existing_tool_name] is original_tool
        assert group._tool_to_session[existing_tool_name] is mock_session

    async def test_disconnect_non_existent_server(self):
        """Test disconnecting a server that isn't connected."""
        session = mock.Mock(spec=mcp.ClientSession)
        group = ClientSessionGroup()
        with pytest.raises(McpError):
            await group.disconnect_from_server(session)

    @pytest.mark.parametrize(
        "server_params_instance, client_type_name, patch_target_for_client_func",
        [
            (
                StdioServerParameters(command="test_stdio_cmd"),
                "stdio",
                "mcp.client.session_group.mcp.stdio_client",
            ),
            (
                SseServerParameters(url="http://test.com/sse", timeout=10.0),
                "sse",
                "mcp.client.session_group.sse_client",
            ),  # url, headers, timeout, sse_read_timeout
            (
                StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False),
                "streamablehttp",
                "mcp.client.session_group.streamable_http_client",
            ),  # url, headers, timeout, sse_read_timeout, terminate_on_close
        ],
    )
    async def test_establish_session_parameterized(
        self,
        server_params_instance: StdioServerParameters | SseServerParameters | StreamableHttpParameters,
        client_type_name: str,  # label used for mock names and per-transport assertions
        patch_target_for_client_func: str,
    ):
        """``_establish_session`` opens the transport, wraps it in a ClientSession and initializes it."""
        with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class:
            with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
                mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM")
                mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
                mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")

                # streamable_http_client's __aenter__ returns three values
                if client_type_name == "streamablehttp":
                    mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
                    mock_client_cm_instance.__aenter__.return_value = (
                        mock_read_stream,
                        mock_write_stream,
                        mock_extra_stream_val,
                    )
                else:
                    mock_client_cm_instance.__aenter__.return_value = (
                        mock_read_stream,
                        mock_write_stream,
                    )

                mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
                mock_specific_client_func.return_value = mock_client_cm_instance

                # --- Mock mcp.ClientSession (class) ---
                # mock_ClientSession_class is already provided by the outer patch
                mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
                mock_ClientSession_class.return_value = mock_raw_session_cm

                mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
                mock_raw_session_cm.__aenter__.return_value = mock_entered_session
                mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)

                # Mock session.initialize()
                mock_initialize_result = mock.AsyncMock(name="InitializeResult")
                mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1")
                mock_entered_session.initialize.return_value = mock_initialize_result

                # --- Test Execution ---
                group = ClientSessionGroup()
                # Pre-bound, presumably so type checkers do not report these as
                # "possibly unbound" after the ``async with`` (AsyncExitStack's
                # ``__aexit__`` is typed as able to suppress exceptions).
                returned_server_info = None
                returned_session = None

                async with contextlib.AsyncExitStack() as stack:
                    group._exit_stack = stack
                    (
                        returned_server_info,
                        returned_session,
                    ) = await group._establish_session(server_params_instance, ClientSessionParameters())

                # --- Assertions ---
                # 1. Assert the correct specific client function was called
                if client_type_name == "stdio":
                    assert isinstance(server_params_instance, StdioServerParameters)
                    mock_specific_client_func.assert_called_once_with(server_params_instance)
                elif client_type_name == "sse":
                    assert isinstance(server_params_instance, SseServerParameters)
                    mock_specific_client_func.assert_called_once_with(
                        url=server_params_instance.url,
                        headers=server_params_instance.headers,
                        timeout=server_params_instance.timeout,
                        sse_read_timeout=server_params_instance.sse_read_timeout,
                    )
                elif client_type_name == "streamablehttp":  # pragma: no branch
                    assert isinstance(server_params_instance, StreamableHttpParameters)
                    # Verify streamable_http_client was called with url, http_client and terminate_on_close.
                    # The http_client is created by the real create_mcp_http_client.
                    call_args = mock_specific_client_func.call_args
                    assert call_args.kwargs["url"] == server_params_instance.url
                    assert call_args.kwargs["terminate_on_close"] == server_params_instance.terminate_on_close
                    assert isinstance(call_args.kwargs["http_client"], httpx.AsyncClient)

                mock_client_cm_instance.__aenter__.assert_awaited_once()

                # 2. Assert ClientSession was called correctly
                mock_ClientSession_class.assert_called_once_with(
                    mock_read_stream,
                    mock_write_stream,
                    read_timeout_seconds=None,
                    sampling_callback=None,
                    elicitation_callback=None,
                    list_roots_callback=None,
                    logging_callback=None,
                    message_handler=None,
                    client_info=None,
                )
                mock_raw_session_cm.__aenter__.assert_awaited_once()
                mock_entered_session.initialize.assert_awaited_once()

                # 3. Assert returned values
                assert returned_server_info is mock_initialize_result.serverInfo
                assert returned_session is mock_entered_session