Skip to content

Commit 114226a

Browse files
committed
fix: retain event cache for completed request-wise channels
1 parent ea36e2b commit 114226a

2 files changed

Lines changed: 126 additions & 10 deletions

File tree

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,15 @@ impl LocalSessionWorker {
363363
if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
364364
{
365365
tracing::debug!(http_request_id, "close http request wise channel");
366-
if let Some(channel) = self.tx_router.remove(&http_request_id) {
367-
for resource in channel.resources {
366+
if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
367+
for resource in channel.resources.drain() {
368368
self.resource_router.remove(&resource);
369369
}
370+
// Replace the sender with a closed dummy so no new
371+
// messages are routed here, but the cache stays alive
372+
// for late resume requests.
373+
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
374+
channel.tx.tx = closed_tx;
370375
}
371376
}
372377
} else {
@@ -416,6 +421,7 @@ impl LocalSessionWorker {
416421
async fn establish_request_wise_channel(
417422
&mut self,
418423
) -> Result<StreamableHttpMessageReceiver, SessionError> {
424+
self.tx_router.retain(|_, rw| !rw.tx.tx.is_closed());
419425
let http_request_id = self.next_http_request_id();
420426
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
421427
let starting_index = usize::from(self.session_config.sse_retry.is_some());
@@ -535,24 +541,25 @@ impl LocalSessionWorker {
535541
match last_event_id.http_request_id {
536542
Some(http_request_id) => {
537543
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
538-
// Resume existing request-wise channel
539-
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
540-
let (tx, rx) = channel;
544+
let was_completed = request_wise.tx.tx.is_closed();
545+
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
541546
request_wise.tx.tx = tx;
542547
let index = last_event_id.index;
543-
// sync messages after index
544548
request_wise.tx.sync(index).await?;
549+
if was_completed {
550+
// Close the sender after replaying so the stream ends
551+
// instead of hanging indefinitely.
552+
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
553+
request_wise.tx.tx = closed_tx;
554+
}
545555
Ok(StreamableHttpMessageReceiver {
546556
http_request_id: Some(http_request_id),
547557
inner: rx,
548558
})
549559
} else {
550-
// Request-wise channel completed (POST response already delivered).
551-
// The client's EventSource is reconnecting after the POST SSE stream
552-
// ended. Fall through to common channel handling below.
553560
tracing::debug!(
554561
http_request_id,
555-
"Request-wise channel completed, falling back to common channel"
562+
"Request-wise channel not found, falling back to common channel"
556563
);
557564
self.resume_or_shadow_common(last_event_id.index).await
558565
}

crates/rmcp/tests/test_streamable_http_priming.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,115 @@ async fn test_request_wise_priming_includes_http_request_id() -> anyhow::Result<
197197
Ok(())
198198
}
199199

200+
#[tokio::test]
201+
async fn test_resume_after_request_wise_channel_completed() -> anyhow::Result<()> {
202+
let ct = CancellationToken::new();
203+
204+
let service: StreamableHttpService<Calculator, LocalSessionManager> =
205+
StreamableHttpService::new(
206+
|| Ok(Calculator::new()),
207+
Default::default(),
208+
StreamableHttpServerConfig::default()
209+
.with_sse_keep_alive(None)
210+
.with_cancellation_token(ct.child_token()),
211+
);
212+
213+
let router = axum::Router::new().nest_service("/mcp", service);
214+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
215+
let addr = tcp_listener.local_addr()?;
216+
217+
let handle = tokio::spawn({
218+
let ct = ct.clone();
219+
async move {
220+
let _ = axum::serve(tcp_listener, router)
221+
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
222+
.await;
223+
}
224+
});
225+
226+
let client = reqwest::Client::new();
227+
228+
// Initialize session
229+
let response = client
230+
.post(format!("http://{addr}/mcp"))
231+
.header("Content-Type", "application/json")
232+
.header("Accept", "application/json, text/event-stream")
233+
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#)
234+
.send()
235+
.await?;
236+
assert_eq!(response.status(), 200);
237+
let session_id: SessionId = response.headers()["mcp-session-id"].to_str()?.into();
238+
239+
// Complete handshake
240+
let status = client
241+
.post(format!("http://{addr}/mcp"))
242+
.header("Content-Type", "application/json")
243+
.header("Accept", "application/json, text/event-stream")
244+
.header("mcp-session-id", session_id.to_string())
245+
.header("Mcp-Protocol-Version", "2025-06-18")
246+
.body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#)
247+
.send()
248+
.await?
249+
.status();
250+
assert_eq!(status, 202);
251+
252+
// Call a tool and consume the full response (channel completes)
253+
let body = client
254+
.post(format!("http://{addr}/mcp"))
255+
.header("Content-Type", "application/json")
256+
.header("Accept", "application/json, text/event-stream")
257+
.header("mcp-session-id", session_id.to_string())
258+
.header("Mcp-Protocol-Version", "2025-06-18")
259+
.body(r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"sum","arguments":{"a":1,"b":2}}}"#)
260+
.send()
261+
.await?
262+
.text()
263+
.await?;
264+
265+
let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect();
266+
assert!(
267+
events.len() >= 2,
268+
"expected priming + response, got: {body}"
269+
);
270+
assert!(events[0].contains("id: 0/0"));
271+
assert!(events[1].contains(r#""id":2"#));
272+
273+
// Resume with Last-Event-ID after the channel has completed.
274+
// The cached events should be replayed and the stream should end.
275+
let resume_response = client
276+
.get(format!("http://{addr}/mcp"))
277+
.header("Accept", "text/event-stream")
278+
.header("mcp-session-id", session_id.to_string())
279+
.header("Mcp-Protocol-Version", "2025-06-18")
280+
.header("last-event-id", "0/0")
281+
.timeout(std::time::Duration::from_secs(5))
282+
.send()
283+
.await?;
284+
assert_eq!(resume_response.status(), 200);
285+
286+
let resume_body = resume_response.text().await?;
287+
let resume_events: Vec<&str> = resume_body
288+
.split("\n\n")
289+
.filter(|e| !e.is_empty())
290+
.collect();
291+
assert!(
292+
!resume_events.is_empty(),
293+
"expected replayed events on resume, got empty"
294+
);
295+
296+
// The replayed event should contain the original response
297+
let replayed = resume_events[0];
298+
assert!(
299+
replayed.contains(r#""id":2"#),
300+
"replayed event should contain the tool response, got: {replayed}"
301+
);
302+
303+
ct.cancel();
304+
handle.await?;
305+
306+
Ok(())
307+
}
308+
200309
#[tokio::test]
201310
async fn test_priming_on_stream_close() -> anyhow::Result<()> {
202311
use std::sync::Arc;

0 commit comments

Comments
 (0)