Skip to content

Commit 559120d

Browse files
committed
fix: track completed_at for cache eviction and resume
1 parent 114226a commit 559120d

3 files changed

Lines changed: 90 additions & 185 deletions

File tree

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::{
22
collections::{HashMap, HashSet, VecDeque},
33
num::ParseIntError,
4-
time::Duration,
4+
time::{Duration, Instant},
55
};
66

77
use futures::{Stream, StreamExt};
@@ -285,6 +285,7 @@ impl CachedTx {
285285
struct HttpRequestWise {
286286
resources: HashSet<ResourceKey>,
287287
tx: CachedTx,
288+
completed_at: Option<Instant>,
288289
}
289290

290291
type HttpRequestId = u64;
@@ -355,28 +356,27 @@ pub struct StreamableHttpMessageReceiver {
355356

356357
impl LocalSessionWorker {
357358
fn unregister_resource(&mut self, resource: &ResourceKey) {
358-
if let Some(http_request_id) = self.resource_router.remove(resource) {
359-
tracing::trace!(?resource, http_request_id, "unregister resource");
360-
if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
361-
// It's okay to do so, since we don't handle batch json rpc request anymore
362-
// and this can be refactored after the batch request is removed in the coming version.
363-
if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
364-
{
365-
tracing::debug!(http_request_id, "close http request wise channel");
366-
if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
367-
for resource in channel.resources.drain() {
368-
self.resource_router.remove(&resource);
369-
}
370-
// Replace the sender with a closed dummy so no new
371-
// messages are routed here, but the cache stays alive
372-
// for late resume requests.
373-
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
374-
channel.tx.tx = closed_tx;
375-
}
376-
}
377-
} else {
378-
tracing::warn!(http_request_id, "http request wise channel not found");
379-
}
359+
let Some(http_request_id) = self.resource_router.remove(resource) else {
360+
return;
361+
};
362+
tracing::trace!(?resource, http_request_id, "unregister resource");
363+
let Some(channel) = self.tx_router.get_mut(&http_request_id) else {
364+
tracing::warn!(http_request_id, "http request wise channel not found");
365+
return;
366+
};
367+
if !channel.resources.is_empty() && !matches!(resource, ResourceKey::McpRequestId(_)) {
368+
return;
369+
}
370+
tracing::debug!(http_request_id, "close http request wise channel");
371+
let resources: Vec<_> = channel.resources.drain().collect();
372+
channel.completed_at = Some(Instant::now());
373+
// Close the sender so the client's SSE stream ends,
374+
// but keep the entry so the cache is available for
375+
// late resume requests.
376+
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
377+
channel.tx.tx = closed_tx;
378+
for resource in resources {
379+
self.resource_router.remove(&resource);
380380
}
381381
}
382382
fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
@@ -413,6 +413,11 @@ impl LocalSessionWorker {
413413
self.unregister_resource(&resource);
414414
}
415415
}
416+
fn evict_expired_channels(&mut self) {
417+
let ttl = self.session_config.completed_cache_ttl;
418+
self.tx_router
419+
.retain(|_, rw| rw.completed_at.is_none_or(|at| at.elapsed() < ttl));
420+
}
416421
fn next_http_request_id(&mut self) -> HttpRequestId {
417422
let id = self.next_http_request_id;
418423
self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
@@ -421,7 +426,6 @@ impl LocalSessionWorker {
421426
async fn establish_request_wise_channel(
422427
&mut self,
423428
) -> Result<StreamableHttpMessageReceiver, SessionError> {
424-
self.tx_router.retain(|_, rw| !rw.tx.tx.is_closed());
425429
let http_request_id = self.next_http_request_id();
426430
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
427431
let starting_index = usize::from(self.session_config.sse_retry.is_some());
@@ -430,6 +434,7 @@ impl LocalSessionWorker {
430434
HttpRequestWise {
431435
resources: Default::default(),
432436
tx: CachedTx::new(tx, Some(http_request_id), starting_index),
437+
completed_at: None,
433438
},
434439
);
435440
tracing::debug!(http_request_id, "establish new request wise channel");
@@ -540,29 +545,25 @@ impl LocalSessionWorker {
540545

541546
match last_event_id.http_request_id {
542547
Some(http_request_id) => {
543-
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
544-
let was_completed = request_wise.tx.tx.is_closed();
545-
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
546-
request_wise.tx.tx = tx;
547-
let index = last_event_id.index;
548-
request_wise.tx.sync(index).await?;
549-
if was_completed {
550-
// Close the sender after replaying so the stream ends
551-
// instead of hanging indefinitely.
552-
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
553-
request_wise.tx.tx = closed_tx;
554-
}
555-
Ok(StreamableHttpMessageReceiver {
556-
http_request_id: Some(http_request_id),
557-
inner: rx,
558-
})
559-
} else {
560-
tracing::debug!(
561-
http_request_id,
562-
"Request-wise channel not found, falling back to common channel"
563-
);
564-
self.resume_or_shadow_common(last_event_id.index).await
548+
let request_wise = self
549+
.tx_router
550+
.get_mut(&http_request_id)
551+
.ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
552+
let is_completed = request_wise.completed_at.is_some();
553+
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
554+
request_wise.tx.tx = tx;
555+
let index = last_event_id.index;
556+
request_wise.tx.sync(index).await?;
557+
if is_completed {
558+
// Drop the sender after replaying so the stream ends
559+
// instead of hanging indefinitely.
560+
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
561+
request_wise.tx.tx = closed_tx;
565562
}
563+
Ok(StreamableHttpMessageReceiver {
564+
http_request_id: Some(http_request_id),
565+
inner: rx,
566+
})
566567
}
567568
None => self.resume_or_shadow_common(last_event_id.index).await,
568569
}
@@ -972,6 +973,7 @@ impl Worker for LocalSessionWorker {
972973
let ct = context.cancellation_token.clone();
973974
let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
974975
loop {
976+
self.evict_expired_channels();
975977
let keep_alive_timeout = tokio::time::sleep(keep_alive);
976978
let event = tokio::select! {
977979
event = self.event_rx.recv() => {
@@ -1098,12 +1100,17 @@ pub struct SessionConfig {
10981100
/// stream-identifying event ID to each request-wise SSE stream.
10991101
/// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
11001102
pub sse_retry: Option<Duration>,
1103+
/// How long to retain completed request-wise channel caches for late
1104+
/// resume requests. After this duration, completed entries are evicted
1105+
/// and resume will return an error. Default is 60 seconds.
1106+
pub completed_cache_ttl: Duration,
11011107
}
11021108

11031109
impl SessionConfig {
11041110
pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
11051111
pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
11061112
pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
1113+
pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
11071114
}
11081115

11091116
impl Default for SessionConfig {
@@ -1112,6 +1119,7 @@ impl Default for SessionConfig {
11121119
channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
11131120
keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
11141121
sse_retry: Some(Self::DEFAULT_SSE_RETRY),
1122+
completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
11151123
}
11161124
}
11171125
}

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -478,40 +478,46 @@ where
478478
.and_then(|v| v.to_str().ok())
479479
.map(|s| s.to_owned());
480480
if let Some(last_event_id) = last_event_id {
481-
// check if session has this event id
482-
let stream = self
481+
match self
483482
.session_manager
484483
.resume(&session_id, last_event_id)
485484
.await
486-
.map_err(internal_error_response("resume session"))?;
487-
// Resume doesn't need priming - client already has the event ID
488-
Ok(sse_stream_response(
489-
stream,
490-
self.config.sse_keep_alive,
491-
self.config.cancellation_token.child_token(),
492-
))
493-
} else {
494-
// create standalone stream
495-
let stream = self
496-
.session_manager
497-
.create_standalone_stream(&session_id)
498-
.await
499-
.map_err(internal_error_response("create standalone stream"))?;
500-
// Prepend priming event if sse_retry configured
501-
let stream = if let Some(retry) = self.config.sse_retry {
502-
let priming = ServerSseMessage::priming("0", retry);
503-
futures::stream::once(async move { priming })
504-
.chain(stream)
505-
.left_stream()
506-
} else {
507-
stream.right_stream()
508-
};
509-
Ok(sse_stream_response(
510-
stream,
511-
self.config.sse_keep_alive,
512-
self.config.cancellation_token.child_token(),
513-
))
485+
{
486+
Ok(stream) => {
487+
return Ok(sse_stream_response(
488+
stream,
489+
self.config.sse_keep_alive,
490+
self.config.cancellation_token.child_token(),
491+
));
492+
}
493+
Err(e) => {
494+
// The referenced stream is gone (completed + evicted or
495+
// never existed). Fall through to create a fresh standalone
496+
// stream so EventSource auto-reconnection stays alive
497+
// without replaying events from a different stream.
498+
tracing::debug!("Resume failed ({e}), creating standalone stream");
499+
}
500+
}
514501
}
502+
// Create standalone stream (also the fallback for failed resume)
503+
let stream = self
504+
.session_manager
505+
.create_standalone_stream(&session_id)
506+
.await
507+
.map_err(internal_error_response("create standalone stream"))?;
508+
let stream = if let Some(retry) = self.config.sse_retry {
509+
let priming = ServerSseMessage::priming("0", retry);
510+
futures::stream::once(async move { priming })
511+
.chain(stream)
512+
.left_stream()
513+
} else {
514+
stream.right_stream()
515+
};
516+
Ok(sse_stream_response(
517+
stream,
518+
self.config.sse_keep_alive,
519+
self.config.cancellation_token.child_token(),
520+
))
515521
}
516522

517523
async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>

crates/rmcp/tests/test_streamable_http_priming.rs

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -197,115 +197,6 @@ async fn test_request_wise_priming_includes_http_request_id() -> anyhow::Result<
197197
Ok(())
198198
}
199199

200-
#[tokio::test]
201-
async fn test_resume_after_request_wise_channel_completed() -> anyhow::Result<()> {
202-
let ct = CancellationToken::new();
203-
204-
let service: StreamableHttpService<Calculator, LocalSessionManager> =
205-
StreamableHttpService::new(
206-
|| Ok(Calculator::new()),
207-
Default::default(),
208-
StreamableHttpServerConfig::default()
209-
.with_sse_keep_alive(None)
210-
.with_cancellation_token(ct.child_token()),
211-
);
212-
213-
let router = axum::Router::new().nest_service("/mcp", service);
214-
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
215-
let addr = tcp_listener.local_addr()?;
216-
217-
let handle = tokio::spawn({
218-
let ct = ct.clone();
219-
async move {
220-
let _ = axum::serve(tcp_listener, router)
221-
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
222-
.await;
223-
}
224-
});
225-
226-
let client = reqwest::Client::new();
227-
228-
// Initialize session
229-
let response = client
230-
.post(format!("http://{addr}/mcp"))
231-
.header("Content-Type", "application/json")
232-
.header("Accept", "application/json, text/event-stream")
233-
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#)
234-
.send()
235-
.await?;
236-
assert_eq!(response.status(), 200);
237-
let session_id: SessionId = response.headers()["mcp-session-id"].to_str()?.into();
238-
239-
// Complete handshake
240-
let status = client
241-
.post(format!("http://{addr}/mcp"))
242-
.header("Content-Type", "application/json")
243-
.header("Accept", "application/json, text/event-stream")
244-
.header("mcp-session-id", session_id.to_string())
245-
.header("Mcp-Protocol-Version", "2025-06-18")
246-
.body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#)
247-
.send()
248-
.await?
249-
.status();
250-
assert_eq!(status, 202);
251-
252-
// Call a tool and consume the full response (channel completes)
253-
let body = client
254-
.post(format!("http://{addr}/mcp"))
255-
.header("Content-Type", "application/json")
256-
.header("Accept", "application/json, text/event-stream")
257-
.header("mcp-session-id", session_id.to_string())
258-
.header("Mcp-Protocol-Version", "2025-06-18")
259-
.body(r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"sum","arguments":{"a":1,"b":2}}}"#)
260-
.send()
261-
.await?
262-
.text()
263-
.await?;
264-
265-
let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect();
266-
assert!(
267-
events.len() >= 2,
268-
"expected priming + response, got: {body}"
269-
);
270-
assert!(events[0].contains("id: 0/0"));
271-
assert!(events[1].contains(r#""id":2"#));
272-
273-
// Resume with Last-Event-ID after the channel has completed.
274-
// The cached events should be replayed and the stream should end.
275-
let resume_response = client
276-
.get(format!("http://{addr}/mcp"))
277-
.header("Accept", "text/event-stream")
278-
.header("mcp-session-id", session_id.to_string())
279-
.header("Mcp-Protocol-Version", "2025-06-18")
280-
.header("last-event-id", "0/0")
281-
.timeout(std::time::Duration::from_secs(5))
282-
.send()
283-
.await?;
284-
assert_eq!(resume_response.status(), 200);
285-
286-
let resume_body = resume_response.text().await?;
287-
let resume_events: Vec<&str> = resume_body
288-
.split("\n\n")
289-
.filter(|e| !e.is_empty())
290-
.collect();
291-
assert!(
292-
!resume_events.is_empty(),
293-
"expected replayed events on resume, got empty"
294-
);
295-
296-
// The replayed event should contain the original response
297-
let replayed = resume_events[0];
298-
assert!(
299-
replayed.contains(r#""id":2"#),
300-
"replayed event should contain the tool response, got: {replayed}"
301-
);
302-
303-
ct.cancel();
304-
handle.await?;
305-
306-
Ok(())
307-
}
308-
309200
#[tokio::test]
310201
async fn test_priming_on_stream_close() -> anyhow::Result<()> {
311202
use std::sync::Arc;

0 commit comments

Comments
 (0)