-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Expand file tree
/
Copy pathtest_resource_cleanup.py
More file actions
67 lines (51 loc) · 2.5 KB
/
test_resource_cleanup.py
File metadata and controls
67 lines (51 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Any
from unittest.mock import patch
import anyio
import pytest
from pydantic import TypeAdapter
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestId, SendResultT
from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest
@pytest.mark.anyio
async def test_send_request_stream_cleanup():
    """Test that send_request properly cleans up streams when an exception occurs.

    This test mocks out most of the session functionality to focus on stream cleanup:
    the session's write stream is patched so that every send raises, and the test
    asserts that the number of entries in ``session._response_streams`` is the same
    after the failed ``send_request`` as it was before it.
    """

    # Create a mock session with the minimal required functionality
    class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]):
        async def _send_response(
            self, request_id: RequestId, response: SendResultT | ErrorData
        ) -> None:  # pragma: no cover
            pass

        @property
        def _receive_request_adapter(self) -> TypeAdapter[Any]:
            return TypeAdapter(object)  # pragma: no cover

        @property
        def _receive_notification_adapter(self) -> TypeAdapter[Any]:
            return TypeAdapter(object)  # pragma: no cover

    # Create streams
    write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)
    read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1)

    # Stand-in for the session's _write_stream.send: every outgoing message fails,
    # simulating a transport error in the middle of send_request.
    async def mock_send(*args: Any, **kwargs: Any) -> None:
        raise RuntimeError("Simulated network error")

    # Using the streams as async context managers closes them on every exit path,
    # including when an assertion below fails. Closing an already-closed memory
    # object stream is a no-op, so this is safe even if the session closed some
    # of them itself.
    async with write_stream_send, write_stream_receive, read_stream_send, read_stream_receive:
        # Create the session
        session = TestSession(read_stream_receive, write_stream_send)

        # Create a test request
        request = PingRequest()

        # Record the response streams before the test
        initial_stream_count = len(session._response_streams)

        # Run the test with the patched method
        async with session:
            with patch.object(session._write_stream, "send", mock_send):
                with pytest.raises(RuntimeError, match="Simulated network error"):  # pragma: no branch
                    await session.send_request(request, EmptyResult)

            # Verify that no response streams were leaked. The check runs while the
            # session is still open, so it reflects send_request's own cleanup rather
            # than anything the session might do when it exits.
            assert len(session._response_streams) == initial_stream_count, (
                f"Expected {initial_stream_count} response streams after request, "
                f"but found {len(session._response_streams)}"
            )