Skip to content

Commit 251ebec

Browse files
authored
fix: drain in-flight responses on stdin EOF (#759)
1 parent e709d0d commit 251ebec

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

crates/rmcp/src/service.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,7 @@ where
773773
let mut transport = transport.into_transport();
774774
let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new();
775775
let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new();
776+
let mut response_send_tasks = tokio::task::JoinSet::<()>::new();
776777
#[derive(Debug)]
777778
enum SendTaskResult {
778779
Request {
@@ -884,7 +885,7 @@ where
884885
}
885886
let send = transport.send(m);
886887
let current_span = tracing::Span::current();
887-
tokio::spawn(async move {
888+
response_send_tasks.spawn(async move {
888889
let send_result = send.await;
889890
if let Err(error) = send_result {
890891
tracing::error!(%error, "fail to response message");
@@ -1032,6 +1033,44 @@ where
10321033
}
10331034
}
10341035
};
1036+
1037+
// Drain in-flight handler responses before closing the transport.
1038+
// When stdin EOF or cancellation arrives, spawned handler tasks may still
1039+
// be finishing. We need to:
1040+
// 1. Wait for response sends that were already spawned in the main loop
1041+
// 2. Drain any remaining handler responses from the channel
1042+
let drain_timeout = match &quit_reason {
1043+
QuitReason::Closed => Some(Duration::from_secs(5)),
1044+
QuitReason::Cancelled => Some(Duration::from_secs(2)),
1045+
_ => None,
1046+
};
1047+
if let Some(timeout_duration) = drain_timeout {
1048+
// Drop our sender so the channel closes once all handler task
1049+
// clones finish sending their responses (or are dropped).
1050+
drop(sink_proxy_tx);
1051+
let drain_result = tokio::time::timeout(timeout_duration, async {
1052+
// First, wait for any response sends already dispatched by the
1053+
// main loop (these hold transport write futures).
1054+
while let Some(result) = response_send_tasks.join_next().await {
1055+
if let Err(error) = result {
1056+
tracing::error!(%error, "response send task failed during drain");
1057+
}
1058+
}
1059+
// Then drain any handler responses still in the channel
1060+
// (handlers that finished after the loop broke).
1061+
while let Some(m) = sink_proxy_rx.recv().await {
1062+
if let Err(error) = transport.send(m).await {
1063+
tracing::error!(%error, "failed to send pending response during drain");
1064+
break;
1065+
}
1066+
}
1067+
})
1068+
.await;
1069+
if drain_result.is_err() {
1070+
tracing::warn!("timed out draining in-flight responses");
1071+
}
1072+
}
1073+
10351074
let sink_close_result = transport.close().await;
10361075
if let Err(e) = sink_close_result {
10371076
tracing::error!(%e, "fail to close sink");
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#![cfg(not(feature = "local"))]
2+
// cargo test --test test_inflight_response_drain --features "client server"
3+
4+
use std::{
5+
pin::Pin,
6+
sync::{
7+
Arc,
8+
atomic::{AtomicBool, Ordering},
9+
},
10+
task::{Context, Poll},
11+
time::Duration,
12+
};
13+
14+
use rmcp::{
15+
ServerHandler, ServiceExt,
16+
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
17+
model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo},
18+
service::QuitReason,
19+
tool, tool_handler, tool_router,
20+
};
21+
use tokio::io::{AsyncRead, ReadBuf};
22+
23+
// A slow tool server that sleeps before returning a response.
24+
#[derive(Debug, Clone)]
25+
struct SlowToolServer {
26+
tool_router: ToolRouter<Self>,
27+
}
28+
29+
impl SlowToolServer {
30+
fn new() -> Self {
31+
Self {
32+
tool_router: Self::tool_router(),
33+
}
34+
}
35+
}
36+
37+
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
38+
struct SlowToolRequest {
39+
#[schemars(description = "how long to sleep in milliseconds")]
40+
sleep_ms: u64,
41+
}
42+
43+
#[tool_router]
44+
impl SlowToolServer {
45+
#[tool(description = "A tool that sleeps then returns")]
46+
async fn slow_tool(
47+
&self,
48+
Parameters(SlowToolRequest { sleep_ms }): Parameters<SlowToolRequest>,
49+
) -> String {
50+
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
51+
format!("done after {}ms", sleep_ms)
52+
}
53+
}
54+
55+
#[tool_handler]
56+
impl ServerHandler for SlowToolServer {
57+
fn get_info(&self) -> ServerInfo {
58+
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
59+
}
60+
}
61+
62+
#[derive(Debug, Clone, Default)]
63+
struct DummyClientHandler;
64+
65+
impl rmcp::ClientHandler for DummyClientHandler {
66+
fn get_info(&self) -> ClientInfo {
67+
ClientInfo::default()
68+
}
69+
}
70+
71+
/// An `AsyncRead` wrapper that delegates to the inner reader until signalled,
72+
/// then returns EOF (read 0 bytes).
73+
struct ClosableReader<R> {
74+
inner: R,
75+
eof_flag: Arc<AtomicBool>,
76+
}
77+
78+
impl<R: AsyncRead + Unpin> AsyncRead for ClosableReader<R> {
79+
fn poll_read(
80+
mut self: Pin<&mut Self>,
81+
cx: &mut Context<'_>,
82+
buf: &mut ReadBuf<'_>,
83+
) -> Poll<std::io::Result<()>> {
84+
if self.eof_flag.load(Ordering::Acquire) {
85+
return Poll::Ready(Ok(()));
86+
}
87+
Pin::new(&mut self.inner).poll_read(cx, buf)
88+
}
89+
}
90+
91+
/// When the server's input stream returns EOF while a tool handler is still
92+
/// in-flight, the drain phase should flush pending responses before closing.
93+
#[tokio::test]
94+
async fn test_inflight_response_drain_on_eof() -> anyhow::Result<()> {
95+
// Two unidirectional channels:
96+
// client_write → server_read (client sends requests to server)
97+
// server_write → client_read (server sends responses to client)
98+
let (client_write, server_read) = tokio::io::duplex(4096);
99+
let (server_write, client_read) = tokio::io::duplex(4096);
100+
101+
// Wrap the server's read side so we can signal EOF from the test.
102+
let eof_flag = Arc::new(AtomicBool::new(false));
103+
let closable_read = ClosableReader {
104+
inner: server_read,
105+
eof_flag: eof_flag.clone(),
106+
};
107+
108+
let server_transport = (closable_read, server_write);
109+
let client_transport = (client_read, client_write);
110+
111+
// Start server with slow tool handler
112+
let server_handle = tokio::spawn(async move {
113+
let server = SlowToolServer::new();
114+
let running = server.serve(server_transport).await?;
115+
let reason = running.waiting().await?;
116+
assert!(
117+
matches!(reason, QuitReason::Closed),
118+
"expected Closed quit reason, got {:?}",
119+
reason,
120+
);
121+
anyhow::Ok(())
122+
});
123+
124+
// Start client
125+
let client = DummyClientHandler.serve(client_transport).await?;
126+
127+
// Call the slow tool (200ms sleep). Concurrently, signal the server's
128+
// read side to return EOF after the request has been sent but before
129+
// the handler finishes.
130+
let tool_future = client.call_tool(
131+
CallToolRequestParams::new("slow_tool").with_arguments(
132+
serde_json::json!({ "sleep_ms": 200 })
133+
.as_object()
134+
.unwrap()
135+
.clone(),
136+
),
137+
);
138+
139+
let (tool_result, _) = tokio::join!(tool_future, async {
140+
// Wait for the request to be sent and received by the server,
141+
// then signal EOF on the server's read side.
142+
tokio::time::sleep(Duration::from_millis(50)).await;
143+
eof_flag.store(true, Ordering::Release);
144+
});
145+
146+
// The tool result should still arrive thanks to the drain phase.
147+
let result = tool_result?;
148+
let text = result
149+
.content
150+
.first()
151+
.and_then(|c| c.raw.as_text())
152+
.map(|t| t.text.as_str())
153+
.expect("expected text content in tool result");
154+
assert_eq!(text, "done after 200ms");
155+
156+
server_handle.await??;
157+
Ok(())
158+
}

0 commit comments

Comments
 (0)