Skip to content

Commit ce57bfe

Browse files
committed
fix(server): return stdio parse errors
1 parent ac96f88 commit ce57bfe

3 files changed

Lines changed: 136 additions & 19 deletions

File tree

src/mcp/server/stdio.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ async def run_server():
1717
```
1818
"""
1919

20+
import json
21+
import re
2022
import sys
2123
from contextlib import asynccontextmanager
2224
from io import TextIOWrapper
25+
from typing import Any, cast
2326

2427
import anyio
2528
import anyio.lowlevel
@@ -28,6 +31,50 @@ async def run_server():
2831
from mcp.shared._context_streams import create_context_streams
2932
from mcp.shared.message import SessionMessage
3033

34+
_JSONRPC_ID_PATTERN = re.compile(r'"id"\s*:\s*(-?\d+|"[^"\\]*")')
35+
36+
37+
def _request_id_from_raw_message(line: str) -> types.RequestId | None:
38+
try:
39+
raw_message: Any = json.loads(line)
40+
except Exception:
41+
raw_message = None
42+
43+
if not isinstance(raw_message, dict):
44+
match = _JSONRPC_ID_PATTERN.search(line)
45+
if not match:
46+
return None
47+
48+
raw_request_id = match.group(1)
49+
if raw_request_id.startswith('"'):
50+
return json.loads(raw_request_id)
51+
return int(raw_request_id)
52+
53+
raw_message_dict = cast(dict[str, Any], raw_message)
54+
request_id = raw_message_dict.get("id")
55+
if isinstance(request_id, str) or type(request_id) is int:
56+
return request_id
57+
return None
58+
59+
60+
def _error_response_from_parse_failure(line: str, exc: Exception) -> SessionMessage:
61+
request_id = _request_id_from_raw_message(line)
62+
message = str(exc)
63+
if "Invalid JSON" in message:
64+
code = types.PARSE_ERROR
65+
prefix = "Parse error"
66+
else:
67+
code = types.INVALID_REQUEST
68+
prefix = "Invalid request"
69+
70+
return SessionMessage(
71+
types.JSONRPCError(
72+
jsonrpc="2.0",
73+
id=request_id,
74+
error=types.ErrorData(code=code, message=f"{prefix}: {message}"),
75+
)
76+
)
77+
3178

3279
@asynccontextmanager
3380
async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None):
@@ -53,7 +100,8 @@ async def stdin_reader():
53100
try:
54101
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
55102
except Exception as exc:
56-
await read_stream_writer.send(exc)
103+
error_response = _error_response_from_parse_failure(line, exc)
104+
await write_stream.send(error_response)
57105
continue
58106

59107
session_message = SessionMessage(message)

tests/interaction/transports/test_stdio.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
import tempfile
1818
from pathlib import Path
19+
from typing import TextIO, cast
1920

2021
import anyio
2122
import pytest
@@ -67,7 +68,8 @@ async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess(
6768
async def collect(params: LoggingMessageNotificationParams) -> None:
6869
received.append(params)
6970

70-
with tempfile.TemporaryFile(mode="w+") as errlog:
71+
with tempfile.TemporaryFile(mode="w+") as errlog_file:
72+
errlog = cast(TextIO, errlog_file)
7173
transport = stdio_client(
7274
StdioServerParameters(
7375
command=sys.executable,
@@ -98,9 +100,11 @@ async def collect(params: LoggingMessageNotificationParams) -> None:
98100
assert received == snapshot(
99101
[LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")]
100102
)
101-
# The server writes this line only after its run loop returns on stdin close: seeing it proves
102-
# a self-exit, not the terminate escalation. The capture itself proves stderr passthrough.
103-
assert captured_stderr == snapshot("stdio-echo: clean exit\n")
103+
# The server writes this line only after its run loop returns, which happens when stdin closes:
104+
# seeing it proves the process exited on its own rather than via the transport's terminate
105+
# escalation, without a timing-based assertion. The suffix check keeps the test stable if the
106+
# child interpreter emits dependency warnings before the server's own stderr line.
107+
assert captured_stderr.endswith("stdio-echo: clean exit\n")
104108

105109

106110
@requirement("transport:stdio:stream-purity")

tests/server/test_stdio.py

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import json
23
import sys
34
import threading
45
from collections.abc import AsyncIterator
@@ -9,9 +10,17 @@
910
import pytest
1011

1112
from mcp.server.mcpserver import MCPServer
12-
from mcp.server.stdio import stdio_server
13+
from mcp.server.stdio import _error_response_from_parse_failure, _request_id_from_raw_message, stdio_server
1314
from mcp.shared.message import SessionMessage
14-
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
15+
from mcp.types import (
16+
INVALID_REQUEST,
17+
PARSE_ERROR,
18+
JSONRPCError,
19+
JSONRPCMessage,
20+
JSONRPCRequest,
21+
JSONRPCResponse,
22+
jsonrpc_message_adapter,
23+
)
1524

1625

1726
@pytest.mark.anyio
@@ -68,10 +77,10 @@ async def test_stdio_server_round_trips_messages_over_injected_streams() -> None
6877

6978
@pytest.mark.anyio
7079
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None:
71-
"""Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream.
80+
"""Non-UTF-8 stdin bytes produce an error response without killing the stream.
7281
73-
Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream
74-
exception; subsequent valid messages are still processed.
82+
Invalid bytes are replaced with U+FFFD, then fail JSON parsing and are returned
83+
as a JSON-RPC parse error. Subsequent valid messages are still processed.
7584
"""
7685
# \xff\xfe are invalid UTF-8 start bytes.
7786
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
@@ -80,20 +89,76 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> Non
8089
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
8190
# stdio_server()'s default path wraps it with errors='replace'.
8291
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
83-
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
92+
stdout = io.StringIO()
8493

8594
with anyio.fail_after(5):
86-
async with stdio_server() as (read_stream, write_stream):
87-
await write_stream.aclose()
95+
async with stdio_server(stdout=anyio.AsyncFile(stdout)) as (read_stream, write_stream):
8896
async with read_stream: # pragma: no branch
89-
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
97+
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> error response on stdout
9098
first = await read_stream.receive()
91-
assert isinstance(first, Exception)
99+
assert isinstance(first, SessionMessage)
100+
assert first.message == valid
101+
102+
await write_stream.aclose()
103+
104+
stdout.seek(0)
105+
output = stdout.read()
106+
error = jsonrpc_message_adapter.validate_json(output.strip())
107+
assert isinstance(error, JSONRPCError)
108+
assert error.id is None
109+
assert error.error.code == PARSE_ERROR
110+
111+
112+
@pytest.mark.anyio
113+
async def test_stdio_server_parse_error_completes_id_bearing_request() -> None:
114+
params: object = {"leaf": True}
115+
for index in reversed(range(256)):
116+
params = {f"p{index}": params}
117+
line = json.dumps({"jsonrpc": "2.0", "id": 900256, "method": "ping", "params": params}) + "\n"
118+
119+
stdin = io.StringIO(line)
120+
stdout = io.StringIO()
121+
122+
with anyio.fail_after(5):
123+
async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
124+
read_stream,
125+
write_stream,
126+
):
127+
async with read_stream:
128+
with pytest.raises(anyio.EndOfStream):
129+
await read_stream.receive()
130+
await write_stream.aclose()
131+
132+
stdout.seek(0)
133+
output_lines = stdout.readlines()
134+
assert len(output_lines) == 1
135+
136+
response = jsonrpc_message_adapter.validate_json(output_lines[0].strip())
137+
assert isinstance(response, JSONRPCError)
138+
assert response.id == 900256
139+
assert response.error.code == PARSE_ERROR
140+
assert "Parse error" in response.error.message
141+
142+
143+
def test_stdio_request_id_recovery_edges() -> None:
144+
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":"abc","method":"ping","params":[') == "abc"
145+
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":42,"method":"ping","params":[') == 42
146+
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":-7,"method":1}') == -7
147+
assert _request_id_from_raw_message('{"jsonrpc":"2.0","id":null,"method":1}') is None
148+
assert _request_id_from_raw_message("[]") is None
149+
150+
151+
def test_stdio_invalid_request_response_preserves_string_id() -> None:
152+
line = '{"jsonrpc":"2.0","id":"bad-method","method":1}'
153+
with pytest.raises(Exception) as exc_info:
154+
jsonrpc_message_adapter.validate_json(line)
155+
156+
response = _error_response_from_parse_failure(line, exc_info.value)
92157

93-
# Second line: valid message still comes through
94-
second = await read_stream.receive()
95-
assert isinstance(second, SessionMessage)
96-
assert second.message == valid
158+
assert isinstance(response.message, JSONRPCError)
159+
assert response.message.id == "bad-method"
160+
assert response.message.error.code == INVALID_REQUEST
161+
assert "Invalid request" in response.message.error.message
97162

98163

99164
class _KeepOpenBytesIO(io.BytesIO):

0 commit comments

Comments
 (0)