Skip to content

Commit 114fc84

Browse files
committed
fix(http): drain SSE stream for connection reuse
1 parent 8e22aa2 commit 114fc84

3 files changed

Lines changed: 159 additions & 5 deletions

File tree

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,14 +303,26 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
303303
let Some(message) = message.transpose()? else {
304304
break;
305305
};
306-
let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
306+
let is_response = matches!(
307+
message,
308+
ServerJsonRpcMessage::Response(_) | ServerJsonRpcMessage::Error(_)
309+
);
307310
let yield_result = sse_worker_tx.send(message).await;
308311
if yield_result.is_err() {
309312
tracing::trace!("streamable http transport worker dropped, exiting");
310313
break;
311314
}
312315
if close_on_response && is_response {
313-
tracing::debug!("got response, closing sse stream");
316+
tracing::debug!("got response, draining sse stream for connection reuse");
317+
// Drain remaining stream bytes so the HTTP/1.1 connection can
318+
// be returned to the pool instead of being discarded. The
319+
// server closes the channel shortly after sending the response,
320+
// so this normally completes in microseconds on localhost. The
321+
// timeout guards against servers that keep the stream open.
322+
let _ = tokio::time::timeout(std::time::Duration::from_millis(50), async {
323+
while sse_stream.next().await.is_some() {}
324+
})
325+
.await;
314326
break;
315327
}
316328
}

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ impl LocalSessionWorker {
479479
{
480480
OutboundChannel::RequestWise {
481481
id: *id,
482-
close: false,
482+
close: true,
483483
}
484484
} else {
485485
OutboundChannel::Common
@@ -492,7 +492,7 @@ impl LocalSessionWorker {
492492
{
493493
OutboundChannel::RequestWise {
494494
id: *id,
495-
close: false,
495+
close: true,
496496
}
497497
} else {
498498
OutboundChannel::Common
@@ -510,7 +510,11 @@ impl LocalSessionWorker {
510510
if let Some(request_wise) = self.tx_router.get_mut(&id) {
511511
request_wise.tx.send(message).await;
512512
if close {
513-
self.tx_router.remove(&id);
513+
if let Some(channel) = self.tx_router.remove(&id) {
514+
for resource in channel.resources {
515+
self.resource_router.remove(&resource);
516+
}
517+
}
514518
}
515519
} else {
516520
return Err(SessionError::ChannelClosed(Some(id)));
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#![cfg(all(
2+
feature = "transport-streamable-http-client",
3+
feature = "transport-streamable-http-client-reqwest",
4+
feature = "transport-streamable-http-server",
5+
not(feature = "local")
6+
))]
7+
8+
use std::time::Instant;
9+
10+
use rmcp::{
11+
ServerHandler, ServiceExt,
12+
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
13+
model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo},
14+
schemars, tool, tool_handler, tool_router,
15+
transport::{
16+
StreamableHttpClientTransport,
17+
streamable_http_client::StreamableHttpClientTransportConfig,
18+
streamable_http_server::{
19+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
20+
},
21+
},
22+
};
23+
use tokio_util::sync::CancellationToken;
24+
25+
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
26+
struct SumRequest {
27+
a: i32,
28+
b: i32,
29+
}
30+
31+
#[derive(Debug, Clone)]
32+
struct EchoServer {
33+
tool_router: ToolRouter<Self>,
34+
}
35+
36+
impl EchoServer {
37+
fn new() -> Self {
38+
Self {
39+
tool_router: Self::tool_router(),
40+
}
41+
}
42+
}
43+
44+
#[tool_router]
45+
impl EchoServer {
46+
#[tool(description = "Sum two numbers")]
47+
fn sum(&self, Parameters(SumRequest { a, b }): Parameters<SumRequest>) -> String {
48+
(a + b).to_string()
49+
}
50+
}
51+
52+
#[tool_handler(router = self.tool_router)]
53+
impl ServerHandler for EchoServer {
54+
fn get_info(&self) -> ServerInfo {
55+
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
56+
}
57+
}
58+
59+
/// Verify that subsequent tool calls do not regress in latency due to
60+
/// HTTP/1.1 connection pool exhaustion. Before the fix, each POST SSE
61+
/// response was dropped without fully consuming the body, preventing
62+
/// connection reuse and forcing a new TCP connection (~40 ms) per call.
63+
#[tokio::test]
64+
async fn test_subsequent_tool_calls_reuse_connections() -> anyhow::Result<()> {
65+
let ct = CancellationToken::new();
66+
67+
let service: StreamableHttpService<EchoServer, LocalSessionManager> =
68+
StreamableHttpService::new(
69+
|| Ok(EchoServer::new()),
70+
Default::default(),
71+
StreamableHttpServerConfig::default()
72+
.with_sse_keep_alive(None)
73+
.with_cancellation_token(ct.child_token()),
74+
);
75+
76+
let router = axum::Router::new().nest_service("/mcp", service);
77+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
78+
let addr = listener.local_addr()?;
79+
80+
let server_handle = tokio::spawn({
81+
let ct = ct.clone();
82+
async move {
83+
let _ = axum::serve(listener, router)
84+
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
85+
.await;
86+
}
87+
});
88+
89+
let transport = StreamableHttpClientTransport::from_config(
90+
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
91+
);
92+
let client = ClientInfo::default().serve(transport).await?;
93+
94+
// Warm up: first call may include one-time setup costs.
95+
let args: serde_json::Map<String, serde_json::Value> =
96+
serde_json::from_value(serde_json::json!({"a": 1, "b": 2}))?;
97+
let _ = client
98+
.call_tool(CallToolRequestParams::new("sum").with_arguments(args))
99+
.await?;
100+
101+
// Measure subsequent calls.
102+
let mut durations = Vec::new();
103+
for i in 0..5i32 {
104+
let args: serde_json::Map<String, serde_json::Value> =
105+
serde_json::from_value(serde_json::json!({"a": i, "b": i + 1}))?;
106+
let start = Instant::now();
107+
let result = client
108+
.call_tool(CallToolRequestParams::new("sum").with_arguments(args))
109+
.await?;
110+
let elapsed = start.elapsed();
111+
durations.push(elapsed);
112+
113+
assert!(
114+
result.is_error != Some(true),
115+
"tool call should succeed, got error: {:?}",
116+
result.content
117+
);
118+
}
119+
120+
let _ = client.cancel().await;
121+
ct.cancel();
122+
server_handle.await?;
123+
124+
// With connection reuse, localhost calls should complete well under 20 ms.
125+
// Before the fix, they consistently took ~42 ms due to new TCP connections.
126+
let max_allowed = std::time::Duration::from_millis(20);
127+
for (i, d) in durations.iter().enumerate() {
128+
assert!(
129+
*d < max_allowed,
130+
"call {} took {:?}, expected < {:?} (connection reuse may be broken)",
131+
i + 1,
132+
d,
133+
max_allowed,
134+
);
135+
}
136+
137+
Ok(())
138+
}

0 commit comments

Comments
 (0)