Skip to content

Commit 3ea8c3c

Browse files
authored
feat: add configuration for transparent session re-init (#760)
* feat: add configuration for transparent session re-init * fix: in ci revert running tests without local until all tests pass * fix: pr comments * fix: documentation
1 parent 251ebec commit 3ea8c3c

File tree

3 files changed

+203
-95
lines changed

3 files changed

+203
-95
lines changed

crates/rmcp/Cargo.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,15 @@ path = "tests/test_sse_concurrent_streams.rs"
265265
name = "test_client_credentials"
266266
required-features = ["auth"]
267267
path = "tests/test_client_credentials.rs"
268+
269+
[[test]]
270+
name = "test_streamable_http_stale_session"
271+
required-features = [
272+
"server",
273+
"client",
274+
"transport-streamable-http-server",
275+
"transport-streamable-http-client",
276+
"transport-streamable-http-client-reqwest"
277+
]
278+
path = "tests/test_streamable_http_stale_session.rs"
279+

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 122 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -600,48 +600,51 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
600600
.await;
601601
let send_result = match response {
602602
Err(StreamableHttpError::SessionExpired) => {
603-
// The server discarded the session (HTTP 404). Perform a
604-
// fresh handshake once and replay the original message.
605-
tracing::info!(
606-
"session expired (HTTP 404), attempting transparent re-initialization"
607-
);
608-
match Self::perform_reinitialization(
609-
self.client.clone(),
610-
saved_init_request.clone(),
611-
config.uri.clone(),
612-
config.auth_header.clone(),
613-
config.custom_headers.clone(),
614-
)
615-
.await
616-
{
617-
Ok((new_session_id, new_protocol_headers)) => {
618-
// Old streams hold the stale session ID; abort them
619-
// so the new standalone SSE stream takes over.
620-
streams.abort_all();
603+
if !config.reinit_on_expired_session {
604+
Err(StreamableHttpError::SessionExpired)
605+
} else {
606+
// The server discarded the session (HTTP 404). Perform a
607+
// fresh handshake once and replay the original message.
608+
tracing::info!(
609+
"session expired (HTTP 404), attempting transparent re-initialization"
610+
);
611+
match Self::perform_reinitialization(
612+
self.client.clone(),
613+
saved_init_request.clone(),
614+
config.uri.clone(),
615+
config.auth_header.clone(),
616+
config.custom_headers.clone(),
617+
)
618+
.await
619+
{
620+
Ok((new_session_id, new_protocol_headers)) => {
621+
// Old streams hold the stale session ID; abort them
622+
// so the new standalone SSE stream takes over.
623+
streams.abort_all();
621624

622-
session_id = new_session_id;
623-
protocol_headers = new_protocol_headers;
624-
session_cleanup_info =
625-
session_id.as_ref().map(|sid| SessionCleanupInfo {
626-
client: self.client.clone(),
627-
uri: config.uri.clone(),
628-
session_id: sid.clone(),
629-
auth_header: config.auth_header.clone(),
630-
protocol_headers: protocol_headers.clone(),
631-
});
625+
session_id = new_session_id;
626+
protocol_headers = new_protocol_headers;
627+
session_cleanup_info =
628+
session_id.as_ref().map(|sid| SessionCleanupInfo {
629+
client: self.client.clone(),
630+
uri: config.uri.clone(),
631+
session_id: sid.clone(),
632+
auth_header: config.auth_header.clone(),
633+
protocol_headers: protocol_headers.clone(),
634+
});
632635

633-
if let Some(new_sid) = &session_id {
634-
let client = self.client.clone();
635-
let uri = config.uri.clone();
636-
let new_sid = new_sid.clone();
637-
let auth_header = config.auth_header.clone();
638-
let retry_config = self.config.retry_config.clone();
639-
let sse_tx = sse_worker_tx.clone();
640-
let task_ct = transport_task_ct.clone();
641-
let config_uri = config.uri.clone();
642-
let config_auth = config.auth_header.clone();
643-
let spawn_headers = protocol_headers.clone();
644-
streams.spawn(async move {
636+
if let Some(new_sid) = &session_id {
637+
let client = self.client.clone();
638+
let uri = config.uri.clone();
639+
let new_sid = new_sid.clone();
640+
let auth_header = config.auth_header.clone();
641+
let retry_config = self.config.retry_config.clone();
642+
let sse_tx = sse_worker_tx.clone();
643+
let task_ct = transport_task_ct.clone();
644+
let config_uri = config.uri.clone();
645+
let config_auth = config.auth_header.clone();
646+
let spawn_headers = protocol_headers.clone();
647+
streams.spawn(async move {
645648
match client
646649
.get_stream(
647650
uri,
@@ -686,69 +689,71 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
686689
}
687690
}
688691
});
689-
}
690-
691-
let retry_response = self
692-
.client
693-
.post_message(
694-
config.uri.clone(),
695-
message,
696-
session_id.clone(),
697-
config.auth_header.clone(),
698-
protocol_headers.clone(),
699-
)
700-
.await;
701-
match retry_response {
702-
Err(e) => Err(e),
703-
Ok(StreamableHttpPostResponse::Accepted) => {
704-
tracing::trace!(
705-
"client message accepted after re-init"
706-
);
707-
Ok(())
708-
}
709-
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
710-
context.send_to_handler(msg).await?;
711-
Ok(())
712692
}
713-
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
714-
if let Some(sid) = &session_id {
715-
let sse_stream = SseAutoReconnectStream::new(
716-
stream,
717-
StreamableHttpClientReconnect {
718-
client: self.client.clone(),
719-
session_id: sid.clone(),
720-
uri: config.uri.clone(),
721-
auth_header: config.auth_header.clone(),
722-
custom_headers: protocol_headers.clone(),
723-
},
724-
self.config.retry_config.clone(),
693+
694+
let retry_response = self
695+
.client
696+
.post_message(
697+
config.uri.clone(),
698+
message,
699+
session_id.clone(),
700+
config.auth_header.clone(),
701+
protocol_headers.clone(),
702+
)
703+
.await;
704+
match retry_response {
705+
Err(e) => Err(e),
706+
Ok(StreamableHttpPostResponse::Accepted) => {
707+
tracing::trace!(
708+
"client message accepted after re-init"
725709
);
726-
streams.spawn(Self::execute_sse_stream(
727-
sse_stream,
728-
sse_worker_tx.clone(),
729-
true,
730-
transport_task_ct.child_token(),
731-
));
732-
} else {
733-
let sse_stream =
710+
Ok(())
711+
}
712+
Ok(StreamableHttpPostResponse::Json(msg, ..)) => {
713+
context.send_to_handler(msg).await?;
714+
Ok(())
715+
}
716+
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
717+
if let Some(sid) = &session_id {
718+
let sse_stream = SseAutoReconnectStream::new(
719+
stream,
720+
StreamableHttpClientReconnect {
721+
client: self.client.clone(),
722+
session_id: sid.clone(),
723+
uri: config.uri.clone(),
724+
auth_header: config.auth_header.clone(),
725+
custom_headers: protocol_headers
726+
.clone(),
727+
},
728+
self.config.retry_config.clone(),
729+
);
730+
streams.spawn(Self::execute_sse_stream(
731+
sse_stream,
732+
sse_worker_tx.clone(),
733+
true,
734+
transport_task_ct.child_token(),
735+
));
736+
} else {
737+
let sse_stream =
734738
SseAutoReconnectStream::never_reconnect(
735739
stream,
736740
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
737741
);
738-
streams.spawn(Self::execute_sse_stream(
739-
sse_stream,
740-
sse_worker_tx.clone(),
741-
true,
742-
transport_task_ct.child_token(),
743-
));
742+
streams.spawn(Self::execute_sse_stream(
743+
sse_stream,
744+
sse_worker_tx.clone(),
745+
true,
746+
transport_task_ct.child_token(),
747+
));
748+
}
749+
tracing::trace!("got new sse stream after re-init");
750+
Ok(())
744751
}
745-
tracing::trace!("got new sse stream after re-init");
746-
Ok(())
747752
}
748753
}
754+
Err(reinit_err) => Err(reinit_err),
749755
}
750-
Err(reinit_err) => Err(reinit_err),
751-
}
756+
} // else enable_reinit_on_expired_session
752757
}
753758
Err(e) => Err(e),
754759
Ok(StreamableHttpPostResponse::Accepted) => {
@@ -1051,6 +1056,16 @@ pub struct StreamableHttpClientTransportConfig {
10511056
pub auth_header: Option<String>,
10521057
/// Custom HTTP headers to include with every request
10531058
pub custom_headers: HashMap<HeaderName, HeaderValue>,
1059+
/// Enables transparent recovery when the server reports an expired session (`HTTP 404`).
1060+
///
1061+
/// When enabled, the transport performs one automatic recovery attempt:
1062+
/// 1. Replays the original `initialize` handshake to create a new session.
1063+
/// 2. Re-establishes streaming state for that session.
1064+
/// 3. Retries the in-flight request that failed with `SessionExpired`.
1065+
///
1066+
/// This recovery is best-effort and bounded to a single attempt. If recovery fails,
1067+
/// the original failure path is preserved and the error is returned to the caller.
1068+
pub reinit_on_expired_session: bool,
10541069
}
10551070

10561071
impl StreamableHttpClientTransportConfig {
@@ -1098,6 +1113,19 @@ impl StreamableHttpClientTransportConfig {
10981113
self.custom_headers = custom_headers;
10991114
self
11001115
}
1116+
1117+
/// Set whether the transport should attempt transparent re-initialization on session expiration
1118+
/// See [`Self::reinit_on_expired_session`] for details.
1119+
/// # Example
1120+
/// ```rust,no_run
1121+
/// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
1122+
/// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
1123+
/// .reinit_on_expired_session(true);
1124+
/// ```
1125+
pub fn reinit_on_expired_session(mut self, enable: bool) -> Self {
1126+
self.reinit_on_expired_session = enable;
1127+
self
1128+
}
11011129
}
11021130

11031131
impl Default for StreamableHttpClientTransportConfig {
@@ -1109,6 +1137,7 @@ impl Default for StreamableHttpClientTransportConfig {
11091137
allow_stateless: true,
11101138
auth_header: None,
11111139
custom_headers: HashMap::new(),
1140+
reinit_on_expired_session: true,
11121141
}
11131142
}
11141143
}

0 commit comments

Comments
 (0)