|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from typing import Any |
| 4 | + |
3 | 5 | import pytest |
4 | 6 | from logfire.testing import CaptureLogfire |
5 | 7 |
|
6 | 8 | from mcp import types |
7 | 9 | from mcp.client.client import Client |
| 10 | +from mcp.server.context import ServerRequestContext |
| 11 | +from mcp.server.lowlevel.server import Server |
8 | 12 | from mcp.server.mcpserver import MCPServer |
| 13 | +from mcp.shared.exceptions import MCPError |
9 | 14 |
|
10 | 15 | pytestmark = pytest.mark.anyio |
11 | 16 |
|
@@ -67,5 +72,53 @@ def greet(name: str) -> str: |
67 | 72 | assert session_metric["unit"] == "s" |
68 | 73 | [session_point] = session_metric["data"]["data_points"] |
69 | 74 | assert session_point["attributes"]["mcp.protocol.version"] == "2025-11-25" |
| 75 | + assert "error.type" not in session_point["attributes"] |
70 | 76 | assert session_point["count"] == 1 |
71 | 77 | assert session_point["sum"] > 0 |
| 78 | + |
| 79 | + |
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_server_operation_error_metrics(capfire: CaptureLogfire):
    """Verify that error.type and rpc.response.status_code are set when a handler raises MCPError."""

    async def handle_call_tool(
        ctx: ServerRequestContext[Any], params: types.CallToolRequestParams
    ) -> types.CallToolResult:
        # Handler that always rejects its arguments, driving the error path.
        raise MCPError(types.INVALID_PARAMS, "bad params")

    server = Server("test", on_call_tool=handle_call_tool)

    async with Client(server) as client:
        # The client surfaces the handler's MCPError to the caller.
        with pytest.raises(MCPError):
            await client.call_tool("boom", {})

    # Collect only the MCP instruments and index them by name.
    metrics = {}
    for metric in capfire.get_collected_metrics():
        if metric["name"].startswith("mcp."):
            metrics[metric["name"]] = metric

    op_points = metrics["mcp.server.operation.duration"]["data"]["data_points"]
    candidates = (p for p in op_points if p["attributes"]["mcp.method.name"] == "tools/call")
    error_point = next(candidates)

    # Both error attributes carry the stringified JSON-RPC error code.
    expected_code = str(types.INVALID_PARAMS)
    assert error_point["attributes"]["error.type"] == expected_code
    assert error_point["attributes"]["rpc.response.status_code"] == expected_code
| 100 | + |
| 101 | + |
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
async def test_server_session_error_metrics(capfire: CaptureLogfire):
    """Verify that error.type is set on session duration when the session exits with an exception."""

    async def handle_call_tool(
        ctx: ServerRequestContext[Any], params: types.CallToolRequestParams
    ) -> types.CallToolResult:
        # An unhandled (non-MCP) exception, as opposed to a protocol-level error.
        raise RuntimeError("unexpected crash")

    server = Server("test", on_call_tool=handle_call_tool)

    # raise_exceptions=True lets the RuntimeError escape the handler and crash the session,
    # simulating what happens in production when an unhandled exception exits the session block.
    with pytest.raises(Exception):
        async with Client(server, raise_exceptions=True) as client:
            await client.call_tool("boom", {})

    # Index the collected MCP instruments by name.
    metrics = {}
    for metric in capfire.get_collected_metrics():
        if metric["name"].startswith("mcp."):
            metrics[metric["name"]] = metric

    session_points = metrics["mcp.server.session.duration"]["data"]["data_points"]
    error_session_points = [
        point for point in session_points if "error.type" in point["attributes"]
    ]
    assert len(error_session_points) >= 1

    # anyio wraps task group exceptions in ExceptionGroup
    recorded_type = error_session_points[0]["attributes"]["error.type"]
    assert recorded_type in ("RuntimeError", "ExceptionGroup")
0 commit comments