forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_lowlevel_exception_handling.py
More file actions
88 lines (65 loc) · 3.07 KB
/
test_lowlevel_exception_handling.py
File metadata and controls
88 lines (65 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from unittest.mock import AsyncMock, Mock
import anyio
import pytest
from mcp import types
from mcp.server.lowlevel.server import Server
from mcp.server.session import ServerSession
from mcp.shared.session import RequestResponder
@pytest.mark.anyio
async def test_exception_handling_with_raise_exceptions_true():
    """An exception handed to ``_handle_message`` is logged and then re-raised.

    With ``raise_exceptions=True`` the original exception must propagate to the
    caller, but the session's ``send_log_message`` is still expected to have been
    awaited exactly once before that happens.
    """
    mock_session = Mock(spec=ServerSession)
    # AsyncMock so the awaited log call can be inspected afterwards.
    mock_session.send_log_message = AsyncMock()
    server = Server("test-server")
    error = RuntimeError("Test error")

    with pytest.raises(RuntimeError, match="Test error"):
        await server._handle_message(error, mock_session, {}, raise_exceptions=True)

    # Logging must still happen even though the exception propagates.
    mock_session.send_log_message.assert_called_once()
@pytest.mark.anyio
@pytest.mark.parametrize(
    "exception_class,message",
    [
        (ValueError, "Test validation error"),
        (RuntimeError, "Test runtime error"),
        (KeyError, "Test key error"),
        (Exception, "Basic error"),
    ],
)
async def test_exception_handling_with_raise_exceptions_false(exception_class: type[Exception], message: str):
    """With ``raise_exceptions=False`` the exception is swallowed and reported as a log message.

    For each parametrized exception type, ``_handle_message`` must return normally
    and await ``send_log_message`` exactly once with an error-level, generic
    payload on the ``mcp.server.exception_handler`` logger.
    """
    server = Server("test-server")
    session = Mock(spec=ServerSession)
    log_mock = AsyncMock()
    session.send_log_message = log_mock

    await server._handle_message(exception_class(message), session, {}, raise_exceptions=False)

    # A single log call carrying the generic error payload (not the exception text).
    log_mock.assert_called_once()
    sent = log_mock.call_args.kwargs
    assert (sent["level"], sent["data"], sent["logger"]) == (
        "error",
        "Internal Server Error",
        "mcp.server.exception_handler",
    )
@pytest.mark.anyio
@pytest.mark.parametrize("stream_error_class", [anyio.ClosedResourceError, anyio.BrokenResourceError])
async def test_exception_handling_ignores_closed_log_stream(stream_error_class: type[Exception]):
    """Logging an exception should not crash shutdown if the write stream is already gone.

    ``send_log_message`` is made to raise the parametrized anyio stream error; with
    ``raise_exceptions=False`` the ``_handle_message`` call is expected to return
    normally after attempting the log exactly once.
    """
    server = Server("test-server")
    session = Mock(spec=ServerSession)
    # Instantiate a fresh error for every run. An instance created at collection time
    # would be shared (and have its traceback mutated when raised) across all runs of
    # this parameter, e.g. if the anyio plugin runs the test on more than one backend.
    # Parametrizing classes also yields readable test IDs and mirrors the
    # exception_class parametrization used by the raise_exceptions=False test.
    session.send_log_message = AsyncMock(side_effect=stream_error_class())

    await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False)

    session.send_log_message.assert_called_once()
@pytest.mark.anyio
async def test_normal_message_handling_not_affected():
    """A regular request responder is still routed to ``_handle_request``.

    The exception-handling branch of ``_handle_message`` must not interfere with
    ordinary messages: passing a ping-request responder should result in exactly
    one call to the (stubbed) ``_handle_request``.
    """
    server = Server("test-server")
    # Replace the real request dispatcher with an AsyncMock to avoid complex setup.
    request_handler = AsyncMock()
    server._handle_request = request_handler
    session = Mock(spec=ServerSession)

    # Spec'd stand-in for a RequestResponder carrying a ping request. __enter__/__exit__
    # are configured in case _handle_message uses the responder as a context manager.
    fake_responder = Mock(spec=RequestResponder)
    fake_responder.request = types.PingRequest(method="ping")
    fake_responder.__enter__ = Mock(return_value=fake_responder)
    fake_responder.__exit__ = Mock(return_value=None)

    await server._handle_message(fake_responder, session, {}, raise_exceptions=False)

    request_handler.assert_called_once()