-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_resource_cleanup.py
More file actions
124 lines (95 loc) · 4.57 KB
/
test_resource_cleanup.py
File metadata and controls
124 lines (95 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys
from collections.abc import Callable
from typing import Any
if sys.version_info >= (3, 11):
from builtins import BaseExceptionGroup # pragma: lax no cover
else:
from exceptiongroup import BaseExceptionGroup # pragma: lax no cover
from unittest.mock import patch
import anyio
import pytest
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import TypeAdapter
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamable_http_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestId, SendResultT
from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest
# Shape of the value produced by the ``client_transport`` fixture, a 3-tuple:
#   1. a string -- the fixture passes a server URL here;
#   2. a callable -- the fixture passes ``sse_client`` or ``streamable_http_client``;
#   3. a callable that takes whatever the client yields and returns a
#      (receive stream, send stream) pair.
ClientTransport = tuple[
    str,
    Callable[..., Any],
    Callable[[Any], tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]],
]
@pytest.mark.anyio
async def test_send_request_stream_cleanup():
    """Test that send_request properly cleans up streams when an exception occurs.

    This test mocks out most of the session functionality to focus on stream cleanup:
    the write stream's ``send`` is patched to raise ``RuntimeError``, and the test then
    checks that ``session._response_streams`` has the same number of entries as before
    the failed call.
    """

    # Minimal BaseSession subclass; its overrides are marked no-cover because this
    # test is not expected to invoke them.
    class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
        async def _send_response(
            self, request_id: RequestId, response: SendResultT | ErrorData
        ) -> None:  # pragma: no cover
            pass

        @property
        def _receive_request_adapter(self) -> TypeAdapter[Any]:
            return TypeAdapter(object)  # pragma: no cover

        @property
        def _receive_notification_adapter(self) -> TypeAdapter[Any]:
            return TypeAdapter(object)  # pragma: no cover

    write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
    read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)

    # Everything after stream creation runs under try/finally so the four stream
    # ends are closed even when an assertion below fails.
    try:
        session = TestSession(read_stream_receive, write_stream_send)
        request = PingRequest()

        async def mock_send(*args: Any, **kwargs: Any):
            raise RuntimeError("Simulated network error")

        # Baseline taken before the failing call, so the check below is relative
        # rather than assuming the registry starts empty.
        initial_stream_count = len(session._response_streams)

        with patch.object(session._write_stream, "send", mock_send), pytest.raises(RuntimeError):
            await session.send_request(request, EmptyResult)

        # The failed request must not leave a response stream registered.
        assert len(session._response_streams) == initial_stream_count
    finally:
        await write_stream_send.aclose()
        await write_stream_receive.aclose()
        await read_stream_send.aclose()
        await read_stream_receive.aclose()
@pytest.fixture(params=["sse", "streamable"])
def client_transport(
    request: pytest.FixtureRequest, sse_server_url: str, streamable_server_url: str
) -> ClientTransport:
    """Provide the (url, client factory, stream extractor) triple for each transport.

    Parametrized over ``"sse"`` and ``"streamable"``; both variants use the same
    extractor, which keeps the first two items of whatever the client yields.
    """

    def first_two(yielded: Any) -> tuple[MemoryObjectReceiveStream[Any], MemoryObjectSendStream[Any]]:
        return (yielded[0], yielded[1])

    if request.param != "sse":
        return (streamable_server_url, streamable_http_client, first_two)
    return (sse_server_url, sse_client, first_two)
@pytest.mark.anyio
async def test_generator_exit_on_gc_cleanup(client_transport: ClientTransport) -> None:
    """Suppress GeneratorExit from aclose() during GC cleanup (python/cpython#95571).

    The client context manager is entered by hand and then torn down by calling
    ``aclose()`` on its ``gen`` attribute instead of going through ``__aexit__``.
    There is no explicit assertion: the test passes when none of the awaited
    calls raise.
    """
    server_url, open_client, split_streams = client_transport
    client_cm = open_client(server_url)
    receive_end, send_end = split_streams(await client_cm.__aenter__())

    # Tear down via the generator itself, bypassing __aexit__.
    await client_cm.gen.aclose()

    await receive_end.aclose()
    await send_end.aclose()
@pytest.mark.anyio
async def test_generator_exit_in_exception_group(client_transport: ClientTransport) -> None:
    """Extract GeneratorExit from BaseExceptionGroup (python/cpython#135736).

    A group whose only member is ``GeneratorExit`` is raised inside the client
    context. There is no ``pytest.raises`` here, so the test passes only if
    nothing escapes the ``async with`` block.
    """
    server_url, open_client, split_streams = client_transport
    async with open_client(server_url) as yielded:
        split_streams(yielded)
        generator_exit_only = BaseExceptionGroup("unhandled errors in a TaskGroup", [GeneratorExit()])
        raise generator_exit_only
@pytest.mark.anyio
async def test_generator_exit_mixed_group(client_transport: ClientTransport) -> None:
    """Extract GeneratorExit from BaseExceptionGroup, re-raise other exceptions (python/cpython#135736).

    A group holding both ``GeneratorExit`` and a ``ValueError`` is raised inside
    the client context. A ``BaseExceptionGroup`` must still propagate, and the
    group that reaches the test must contain no ``GeneratorExit`` at any depth.
    """
    server_url, open_client, split_streams = client_transport

    def has_generator_exit(eg: BaseExceptionGroup[Any]) -> bool:
        """Report whether *eg*, or any group nested inside it, holds a GeneratorExit."""
        return any(
            isinstance(member, GeneratorExit)
            or (isinstance(member, BaseExceptionGroup) and has_generator_exit(eg=member))  # type: ignore[arg-type]
            for member in eg.exceptions
        )

    with pytest.raises(BaseExceptionGroup) as exc_info:
        async with open_client(server_url) as yielded:
            split_streams(yielded)
            raise BaseExceptionGroup("errors", [GeneratorExit(), ValueError("real error")])

    assert not has_generator_exit(exc_info.value)