11use std:: {
22 collections:: { HashMap , HashSet , VecDeque } ,
33 num:: ParseIntError ,
4- time:: Duration ,
4+ time:: { Duration , Instant } ,
55} ;
66
77use futures:: { Stream , StreamExt } ;
@@ -285,6 +285,7 @@ impl CachedTx {
285285struct HttpRequestWise {
286286 resources : HashSet < ResourceKey > ,
287287 tx : CachedTx ,
288+ completed_at : Option < Instant > ,
288289}
289290
290291type HttpRequestId = u64 ;
@@ -355,28 +356,27 @@ pub struct StreamableHttpMessageReceiver {
355356
356357impl LocalSessionWorker {
357358 fn unregister_resource ( & mut self , resource : & ResourceKey ) {
358- if let Some ( http_request_id) = self . resource_router . remove ( resource) {
359- tracing:: trace!( ?resource, http_request_id, "unregister resource" ) ;
360- if let Some ( channel) = self . tx_router . get_mut ( & http_request_id) {
361- // It's okey to do so, since we don't handle batch json rpc request anymore
362- // and this can be refactored after the batch request is removed in the coming version.
363- if channel. resources . is_empty ( ) || matches ! ( resource, ResourceKey :: McpRequestId ( _) )
364- {
365- tracing:: debug!( http_request_id, "close http request wise channel" ) ;
366- if let Some ( channel) = self . tx_router . get_mut ( & http_request_id) {
367- for resource in channel. resources . drain ( ) {
368- self . resource_router . remove ( & resource) ;
369- }
370- // Replace the sender with a closed dummy so no new
371- // messages are routed here, but the cache stays alive
372- // for late resume requests.
373- let ( closed_tx, _) = tokio:: sync:: mpsc:: channel ( 1 ) ;
374- channel. tx . tx = closed_tx;
375- }
376- }
377- } else {
378- tracing:: warn!( http_request_id, "http request wise channel not found" ) ;
379- }
359+ let Some ( http_request_id) = self . resource_router . remove ( resource) else {
360+ return ;
361+ } ;
362+ tracing:: trace!( ?resource, http_request_id, "unregister resource" ) ;
363+ let Some ( channel) = self . tx_router . get_mut ( & http_request_id) else {
364+ tracing:: warn!( http_request_id, "http request wise channel not found" ) ;
365+ return ;
366+ } ;
367+ if !channel. resources . is_empty ( ) && !matches ! ( resource, ResourceKey :: McpRequestId ( _) ) {
368+ return ;
369+ }
370+ tracing:: debug!( http_request_id, "close http request wise channel" ) ;
371+ let resources: Vec < _ > = channel. resources . drain ( ) . collect ( ) ;
372+ channel. completed_at = Some ( Instant :: now ( ) ) ;
373+ // Close the sender so the client's SSE stream ends,
374+ // but keep the entry so the cache is available for
375+ // late resume requests.
376+ let ( closed_tx, _) = tokio:: sync:: mpsc:: channel ( 1 ) ;
377+ channel. tx . tx = closed_tx;
378+ for resource in resources {
379+ self . resource_router . remove ( & resource) ;
380380 }
381381 }
382382 fn register_resource ( & mut self , resource : ResourceKey , http_request_id : HttpRequestId ) {
@@ -413,6 +413,11 @@ impl LocalSessionWorker {
413413 self . unregister_resource ( & resource) ;
414414 }
415415 }
416+ fn evict_expired_channels ( & mut self ) {
417+ let ttl = self . session_config . completed_cache_ttl ;
418+ self . tx_router
419+ . retain ( |_, rw| rw. completed_at . is_none_or ( |at| at. elapsed ( ) < ttl) ) ;
420+ }
416421 fn next_http_request_id ( & mut self ) -> HttpRequestId {
417422 let id = self . next_http_request_id ;
418423 self . next_http_request_id = self . next_http_request_id . wrapping_add ( 1 ) ;
@@ -421,7 +426,6 @@ impl LocalSessionWorker {
421426 async fn establish_request_wise_channel (
422427 & mut self ,
423428 ) -> Result < StreamableHttpMessageReceiver , SessionError > {
424- self . tx_router . retain ( |_, rw| !rw. tx . tx . is_closed ( ) ) ;
425429 let http_request_id = self . next_http_request_id ( ) ;
426430 let ( tx, rx) = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
427431 let starting_index = usize:: from ( self . session_config . sse_retry . is_some ( ) ) ;
@@ -430,6 +434,7 @@ impl LocalSessionWorker {
430434 HttpRequestWise {
431435 resources : Default :: default ( ) ,
432436 tx : CachedTx :: new ( tx, Some ( http_request_id) , starting_index) ,
437+ completed_at : None ,
433438 } ,
434439 ) ;
435440 tracing:: debug!( http_request_id, "establish new request wise channel" ) ;
@@ -540,29 +545,25 @@ impl LocalSessionWorker {
540545
541546 match last_event_id. http_request_id {
542547 Some ( http_request_id) => {
543- if let Some ( request_wise) = self . tx_router . get_mut ( & http_request_id) {
544- let was_completed = request_wise. tx . tx . is_closed ( ) ;
545- let ( tx, rx) = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
546- request_wise. tx . tx = tx;
547- let index = last_event_id. index ;
548- request_wise. tx . sync ( index) . await ?;
549- if was_completed {
550- // Close the sender after replaying so the stream ends
551- // instead of hanging indefinitely.
552- let ( closed_tx, _) = tokio:: sync:: mpsc:: channel ( 1 ) ;
553- request_wise. tx . tx = closed_tx;
554- }
555- Ok ( StreamableHttpMessageReceiver {
556- http_request_id : Some ( http_request_id) ,
557- inner : rx,
558- } )
559- } else {
560- tracing:: debug!(
561- http_request_id,
562- "Request-wise channel not found, falling back to common channel"
563- ) ;
564- self . resume_or_shadow_common ( last_event_id. index ) . await
548+ let request_wise = self
549+ . tx_router
550+ . get_mut ( & http_request_id)
551+ . ok_or ( SessionError :: ChannelClosed ( Some ( http_request_id) ) ) ?;
552+ let is_completed = request_wise. completed_at . is_some ( ) ;
553+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
554+ request_wise. tx . tx = tx;
555+ let index = last_event_id. index ;
556+ request_wise. tx . sync ( index) . await ?;
557+ if is_completed {
558+ // Drop the sender after replaying so the stream ends
559+ // instead of hanging indefinitely.
560+ let ( closed_tx, _) = tokio:: sync:: mpsc:: channel ( 1 ) ;
561+ request_wise. tx . tx = closed_tx;
565562 }
563+ Ok ( StreamableHttpMessageReceiver {
564+ http_request_id : Some ( http_request_id) ,
565+ inner : rx,
566+ } )
566567 }
567568 None => self . resume_or_shadow_common ( last_event_id. index ) . await ,
568569 }
@@ -972,6 +973,7 @@ impl Worker for LocalSessionWorker {
972973 let ct = context. cancellation_token . clone ( ) ;
973974 let keep_alive = self . session_config . keep_alive . unwrap_or ( Duration :: MAX ) ;
974975 loop {
976+ self . evict_expired_channels ( ) ;
975977 let keep_alive_timeout = tokio:: time:: sleep ( keep_alive) ;
976978 let event = tokio:: select! {
977979 event = self . event_rx. recv( ) => {
@@ -1098,12 +1100,17 @@ pub struct SessionConfig {
10981100 /// stream-identifying event ID to each request-wise SSE stream.
10991101 /// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
11001102 pub sse_retry : Option < Duration > ,
1103+ /// How long to retain completed request-wise channel caches for late
1104+ /// resume requests. After this duration, completed entries are evicted
1105+ /// and resume will return an error. Default is 60 seconds.
1106+ pub completed_cache_ttl : Duration ,
11011107}
11021108
11031109impl SessionConfig {
11041110 pub const DEFAULT_CHANNEL_CAPACITY : usize = 16 ;
11051111 pub const DEFAULT_KEEP_ALIVE : Duration = Duration :: from_secs ( 300 ) ;
11061112 pub const DEFAULT_SSE_RETRY : Duration = Duration :: from_secs ( 3 ) ;
1113+ pub const DEFAULT_COMPLETED_CACHE_TTL : Duration = Duration :: from_secs ( 60 ) ;
11071114}
11081115
11091116impl Default for SessionConfig {
@@ -1112,6 +1119,7 @@ impl Default for SessionConfig {
11121119 channel_capacity : Self :: DEFAULT_CHANNEL_CAPACITY ,
11131120 keep_alive : Some ( Self :: DEFAULT_KEEP_ALIVE ) ,
11141121 sse_retry : Some ( Self :: DEFAULT_SSE_RETRY ) ,
1122+ completed_cache_ttl : Self :: DEFAULT_COMPLETED_CACHE_TTL ,
11151123 }
11161124 }
11171125}
0 commit comments