Skip to content

Commit fb63f67

Browse files
committed
fix(streamable-http): handle json tool responses without hanging
1 parent 251ebec commit fb63f67

File tree

3 files changed

+184
-23
lines changed

3 files changed

+184
-23
lines changed

crates/rmcp/Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,13 @@ path = "tests/test_streamable_http_priming.rs"
216216

217217
[[test]]
218218
name = "test_streamable_http_json_response"
219-
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
219+
required-features = [
220+
"server",
221+
"client",
222+
"transport-streamable-http-server",
223+
"transport-streamable-http-client-reqwest",
224+
"reqwest",
225+
]
220226
path = "tests/test_streamable_http_json_response.rs"
221227

222228
[[test]]

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -602,28 +602,56 @@ where
602602
// JSON-direct mode: await the single response and return as
603603
// application/json, eliminating SSE framing overhead.
604604
// Allowed by MCP Streamable HTTP spec (2025-06-18).
605+
//
606+
// Tools may emit progress notifications before their
607+
// final response. In JSON-direct mode there is no
608+
// secondary channel for those notifications, so keep
609+
// draining until we receive the terminal response/error
610+
// message that should satisfy the HTTP request.
605611
let cancel = self.config.cancellation_token.child_token();
606-
match tokio::select! {
607-
res = receiver.recv() => res,
608-
_ = cancel.cancelled() => None,
609-
} {
610-
Some(message) => {
611-
tracing::trace!(?message);
612-
let body = serde_json::to_vec(&message).map_err(|e| {
613-
internal_error_response("serialize json response")(e)
614-
})?;
615-
Ok(Response::builder()
616-
.status(http::StatusCode::OK)
617-
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
618-
.body(Full::new(Bytes::from(body)).boxed())
619-
.expect("valid response"))
612+
loop {
613+
match tokio::select! {
614+
res = receiver.recv() => res,
615+
_ = cancel.cancelled() => None,
616+
} {
617+
Some(
618+
message @ (crate::model::ServerJsonRpcMessage::Response(_)
619+
| crate::model::ServerJsonRpcMessage::Error(_)),
620+
) => {
621+
tracing::trace!(?message);
622+
let body = serde_json::to_vec(&message).map_err(|e| {
623+
internal_error_response("serialize json response")(e)
624+
})?;
625+
break Ok(Response::builder()
626+
.status(http::StatusCode::OK)
627+
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
628+
.body(Full::new(Bytes::from(body)).boxed())
629+
.expect("valid response"));
630+
}
631+
Some(crate::model::ServerJsonRpcMessage::Notification(
632+
notification,
633+
)) => {
634+
tracing::debug!(
635+
?notification,
636+
"dropping server notification while awaiting JSON response"
637+
);
638+
}
639+
Some(crate::model::ServerJsonRpcMessage::Request(request)) => {
640+
tracing::warn!(
641+
?request,
642+
"cannot deliver server request over JSON-direct response"
643+
);
644+
break Err(unexpected_message_response("response or error"));
645+
}
646+
None => {
647+
break Err(internal_error_response("empty response")(
648+
std::io::Error::new(
649+
std::io::ErrorKind::UnexpectedEof,
650+
"no response message received from handler",
651+
),
652+
));
653+
}
620654
}
621-
None => Err(internal_error_response("empty response")(
622-
std::io::Error::new(
623-
std::io::ErrorKind::UnexpectedEof,
624-
"no response message received from handler",
625-
),
626-
)),
627655
}
628656
} else {
629657
// SSE mode (default): original behaviour preserved unchanged

crates/rmcp/tests/test_streamable_http_json_response.rs

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
#![cfg(not(feature = "local"))]
2-
use rmcp::transport::streamable_http_server::{
3-
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
2+
use std::time::Duration;
3+
4+
use futures::future::BoxFuture;
5+
use rmcp::{
6+
ServerHandler, ServiceExt,
7+
handler::server::{
8+
router::tool::ToolRoute,
9+
tool::{ToolCallContext, ToolRouter, schema_for_type},
10+
},
11+
model::{
12+
CallToolRequestParams, CallToolResult, Content, ProgressNotificationParam,
13+
ServerCapabilities, ServerInfo, Tool,
14+
},
15+
tool_handler,
16+
transport::{
17+
StreamableHttpClientTransport,
18+
streamable_http_client::StreamableHttpClientTransportConfig,
19+
streamable_http_server::{
20+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
21+
},
22+
},
423
};
524
use tokio_util::sync::CancellationToken;
625

@@ -76,6 +95,114 @@ async fn stateless_json_response_returns_application_json() -> anyhow::Result<()
7695
Ok(())
7796
}
7897

98+
#[derive(Debug, Default, serde::Deserialize, schemars::JsonSchema)]
99+
struct EmptyArgs {}
100+
101+
#[derive(Debug, Clone)]
102+
struct ProgressToolServer {
103+
tool_router: ToolRouter<Self>,
104+
}
105+
106+
impl ProgressToolServer {
107+
fn new() -> Self {
108+
Self {
109+
tool_router: ToolRouter::new().with_route(ToolRoute::new_dyn(
110+
Tool::new(
111+
"progress_then_result",
112+
"Emit a progress notification before returning",
113+
schema_for_type::<EmptyArgs>(),
114+
),
115+
|context: ToolCallContext<'_, Self>| -> BoxFuture<'_, _> {
116+
Box::pin(async move {
117+
let Some(progress_token) =
118+
context.request_context.meta.get_progress_token()
119+
else {
120+
return Err(rmcp::ErrorData::invalid_params(
121+
"missing progress token",
122+
None,
123+
));
124+
};
125+
126+
context
127+
.request_context
128+
.peer
129+
.notify_progress(ProgressNotificationParam::new(progress_token, 1.0))
130+
.await
131+
.map_err(|err| {
132+
rmcp::ErrorData::internal_error(
133+
format!("failed to send progress notification: {err}"),
134+
None,
135+
)
136+
})?;
137+
138+
Ok(CallToolResult::success(vec![Content::text("done")]))
139+
})
140+
},
141+
)),
142+
}
143+
}
144+
}
145+
146+
#[tool_handler]
147+
impl ServerHandler for ProgressToolServer {
148+
fn get_info(&self) -> ServerInfo {
149+
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
150+
}
151+
}
152+
153+
#[tokio::test]
154+
async fn stateless_json_response_waits_for_terminal_tool_response() -> anyhow::Result<()> {
155+
let ct = CancellationToken::new();
156+
let service: StreamableHttpService<ProgressToolServer, LocalSessionManager> =
157+
StreamableHttpService::new(
158+
|| Ok(ProgressToolServer::new()),
159+
Default::default(),
160+
StreamableHttpServerConfig {
161+
stateful_mode: false,
162+
json_response: true,
163+
sse_keep_alive: None,
164+
cancellation_token: ct.child_token(),
165+
..Default::default()
166+
},
167+
);
168+
169+
let router = axum::Router::new().nest_service("/mcp", service);
170+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
171+
let addr = tcp_listener.local_addr()?;
172+
173+
let handle = tokio::spawn({
174+
let ct = ct.clone();
175+
async move {
176+
let _ = axum::serve(tcp_listener, router)
177+
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
178+
.await;
179+
}
180+
});
181+
182+
let transport = StreamableHttpClientTransport::from_config(
183+
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
184+
);
185+
let client = ().serve(transport).await?;
186+
187+
let result = tokio::time::timeout(
188+
Duration::from_secs(3),
189+
client.call_tool(CallToolRequestParams::new("progress_then_result")),
190+
)
191+
.await??;
192+
193+
let text = result
194+
.content
195+
.first()
196+
.and_then(|content| content.raw.as_text())
197+
.map(|text| text.text.as_str());
198+
assert_eq!(text, Some("done"));
199+
200+
let _ = client.cancel().await;
201+
ct.cancel();
202+
handle.await?;
203+
Ok(())
204+
}
205+
79206
#[tokio::test]
80207
async fn stateless_sse_mode_default_unchanged() -> anyhow::Result<()> {
81208
let ct = CancellationToken::new();

0 commit comments

Comments
 (0)