Skip to content

Commit 48df0b1

Browse files
committed
Stabilize stdio smoke tests
1 parent fb3c724 commit 48df0b1

1 file changed

Lines changed: 76 additions & 11 deletions

File tree

tests/test_stdio_smoke.py

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import subprocess
1818
import sys
19+
import threading
1920
import time
2021
from pathlib import Path
2122

@@ -100,17 +101,17 @@ def _create_test_index(cache_dir: Path) -> Path:
100101
return db_path
101102

102103

103-
def _make_request(method: str, params: dict | None = None, req_id: int = 1) -> bytes:
104+
def _make_request(method: str, params: dict[str, object] | None = None, req_id: int = 1) -> bytes:
104105
"""Build a JSON-RPC 2.0 request as newline-terminated bytes."""
105-
msg = {"jsonrpc": "2.0", "id": req_id, "method": method}
106+
msg: dict[str, object] = {"jsonrpc": "2.0", "id": req_id, "method": method}
106107
if params is not None:
107108
msg["params"] = params
108109
return json.dumps(msg).encode() + b"\n"
109110

110111

111-
def _make_notification(method: str, params: dict | None = None) -> bytes:
112+
def _make_notification(method: str, params: dict[str, object] | None = None) -> bytes:
112113
"""Build a JSON-RPC 2.0 notification (no id) as newline-terminated bytes."""
113-
msg = {"jsonrpc": "2.0", "method": method}
114+
msg: dict[str, object] = {"jsonrpc": "2.0", "method": method}
114115
if params is not None:
115116
msg["params"] = params
116117
return json.dumps(msg).encode() + b"\n"
@@ -170,6 +171,19 @@ def _find_response(responses: list[dict], req_id: int) -> dict | None:
170171
return None
171172

172173

174+
def _request_ids(stdin_data: bytes) -> list[int]:
175+
"""Return request ids present in newline-delimited JSON-RPC input."""
176+
request_ids: list[int] = []
177+
for line in stdin_data.splitlines():
178+
try:
179+
parsed = json.loads(line)
180+
except json.JSONDecodeError:
181+
continue
182+
if isinstance(parsed, dict) and "id" in parsed:
183+
request_ids.append(parsed["id"])
184+
return request_ids
185+
186+
173187
class TestStdioSmoke:
174188
"""Spawn the MCP server as a subprocess and verify protocol compliance."""
175189

@@ -192,13 +206,64 @@ def _run_server_with_input(
192206
env=self.env,
193207
)
194208
assert proc.stdin is not None
195-
for index, line in enumerate(stdin_data.splitlines(keepends=True)):
196-
proc.stdin.write(line)
197-
proc.stdin.flush()
198-
time.sleep(0.3 if index == 0 else 0.05)
199-
proc.stdin.close()
200-
proc.stdin = None
201-
stdout, stderr = proc.communicate(timeout=timeout)
209+
assert proc.stdout is not None
210+
assert proc.stderr is not None
211+
212+
stdout_lines: list[bytes] = []
213+
stderr_lines: list[bytes] = []
214+
output_lock = threading.Lock()
215+
216+
def read_stream(stream, sink: list[bytes]) -> None:
217+
for line in iter(stream.readline, b""):
218+
with output_lock:
219+
sink.append(line)
220+
221+
stdout_thread = threading.Thread(
222+
target=read_stream,
223+
args=(proc.stdout, stdout_lines),
224+
daemon=True,
225+
)
226+
stderr_thread = threading.Thread(
227+
target=read_stream,
228+
args=(proc.stderr, stderr_lines),
229+
daemon=True,
230+
)
231+
stdout_thread.start()
232+
stderr_thread.start()
233+
234+
expected_ids = _request_ids(stdin_data)
235+
deadline = time.monotonic() + timeout
236+
try:
237+
for line in stdin_data.splitlines(keepends=True):
238+
proc.stdin.write(line)
239+
proc.stdin.flush()
240+
241+
while time.monotonic() < deadline:
242+
with output_lock:
243+
responses = _read_responses(b"".join(stdout_lines))
244+
if all(_find_response(responses, req_id) is not None for req_id in expected_ids):
245+
break
246+
if proc.poll() is not None:
247+
break
248+
time.sleep(0.02)
249+
finally:
250+
try:
251+
proc.stdin.close()
252+
except BrokenPipeError:
253+
pass
254+
proc.stdin = None
255+
256+
remaining = max(0.1, deadline - time.monotonic())
257+
try:
258+
proc.wait(timeout=remaining)
259+
except subprocess.TimeoutExpired:
260+
proc.kill()
261+
proc.wait(timeout=5)
262+
263+
stdout_thread.join(timeout=1)
264+
stderr_thread.join(timeout=1)
265+
stdout = b"".join(stdout_lines)
266+
stderr = b"".join(stderr_lines)
202267
return subprocess.CompletedProcess(proc.args, proc.returncode, stdout, stderr)
203268

204269
def test_server_lists_tools_no_stdout_pollution(self):

0 commit comments

Comments
 (0)