@@ -11,7 +11,7 @@ use tracing::debug;
1111use super :: common:: client_side_sse:: { ExponentialBackoff , SseRetryPolicy , SseStreamReconnect } ;
1212use crate :: {
1313 RoleClient ,
14- model:: { ClientJsonRpcMessage , ServerJsonRpcMessage } ,
14+ model:: { ClientJsonRpcMessage , ServerJsonRpcMessage , ServerResult } ,
1515 transport:: {
1616 common:: client_side_sse:: SseAutoReconnectStream ,
1717 worker:: { Worker , WorkerQuitReason , WorkerSendRequest , WorkerTransport } ,
@@ -184,13 +184,15 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
184184 uri : Arc < str > ,
185185 session_id : Arc < str > ,
186186 auth_header : Option < String > ,
187+ custom_headers : HashMap < HeaderName , HeaderValue > ,
187188 ) -> impl Future < Output = Result < ( ) , StreamableHttpError < Self :: Error > > > + Send + ' _ ;
188189 fn get_stream (
189190 & self ,
190191 uri : Arc < str > ,
191192 session_id : Arc < str > ,
192193 last_event_id : Option < String > ,
193194 auth_header : Option < String > ,
195+ custom_headers : HashMap < HeaderName , HeaderValue > ,
194196 ) -> impl Future <
195197 Output = Result <
196198 BoxStream < ' static , Result < Sse , SseError > > ,
@@ -210,6 +212,7 @@ struct StreamableHttpClientReconnect<C> {
210212 pub session_id : Arc < str > ,
211213 pub uri : Arc < str > ,
212214 pub auth_header : Option < String > ,
215+ pub custom_headers : HashMap < HeaderName , HeaderValue > ,
213216}
214217
215218impl < C : StreamableHttpClient > SseStreamReconnect for StreamableHttpClientReconnect < C > {
@@ -220,15 +223,25 @@ impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconne
220223 let uri = self . uri . clone ( ) ;
221224 let session_id = self . session_id . clone ( ) ;
222225 let auth_header = self . auth_header . clone ( ) ;
226+ let custom_headers = self . custom_headers . clone ( ) ;
223227 let last_event_id = last_event_id. map ( |s| s. to_owned ( ) ) ;
224228 Box :: pin ( async move {
225229 client
226- . get_stream ( uri, session_id, last_event_id, auth_header)
230+ . get_stream ( uri, session_id, last_event_id, auth_header, custom_headers )
227231 . await
228232 } )
229233 }
230234}
231235
236+ /// Info retained for cleaning up the session when the worker exits.
237+ struct SessionCleanupInfo < C > {
238+ client : C ,
239+ uri : Arc < str > ,
240+ session_id : Arc < str > ,
241+ auth_header : Option < String > ,
242+ protocol_headers : HashMap < HeaderName , HeaderValue > ,
243+ }
244+
232245#[ derive( Debug , Clone , Default ) ]
233246pub struct StreamableHttpClientWorker < C : StreamableHttpClient > {
234247 pub client : C ,
@@ -357,14 +370,29 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
357370 }
358371 None
359372 } ;
373+ // Extract the negotiated protocol version from the init response
374+ // and build a custom headers map that includes MCP-Protocol-Version
375+ // for all subsequent HTTP requests (per MCP 2025-06-18 spec).
376+ let protocol_headers = {
377+ let mut headers = config. custom_headers . clone ( ) ;
378+ if let ServerJsonRpcMessage :: Response ( response) = & message {
379+ if let ServerResult :: InitializeResult ( init_result) = & response. result {
380+ if let Ok ( hv) = HeaderValue :: from_str ( init_result. protocol_version . as_str ( ) ) {
381+ // HeaderName::from_static requires lowercase
382+ headers. insert ( HeaderName :: from_static ( "mcp-protocol-version" ) , hv) ;
383+ }
384+ }
385+ }
386+ headers
387+ } ;
388+
360389 // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
361- let session_cleanup_info = session_id. as_ref ( ) . map ( |sid| {
362- (
363- self . client . clone ( ) ,
364- config. uri . clone ( ) ,
365- sid. clone ( ) ,
366- config. auth_header . clone ( ) ,
367- )
390+ let session_cleanup_info = session_id. as_ref ( ) . map ( |sid| SessionCleanupInfo {
391+ client : self . client . clone ( ) ,
392+ uri : config. uri . clone ( ) ,
393+ session_id : sid. clone ( ) ,
394+ auth_header : config. auth_header . clone ( ) ,
395+ protocol_headers : protocol_headers. clone ( ) ,
368396 } ) ;
369397
370398 context. send_to_handler ( message) . await ?;
@@ -376,7 +404,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
376404 initialized_notification. message ,
377405 session_id. clone ( ) ,
378406 config. auth_header . clone ( ) ,
379- config . custom_headers . clone ( ) ,
407+ protocol_headers . clone ( ) ,
380408 )
381409 . await
382410 . map_err ( WorkerQuitReason :: fatal_context (
@@ -404,10 +432,17 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
404432 let transport_task_ct = transport_task_ct. clone ( ) ;
405433 let config_uri = config. uri . clone ( ) ;
406434 let config_auth_header = config. auth_header . clone ( ) ;
435+ let spawn_headers = protocol_headers. clone ( ) ;
407436
408437 streams. spawn ( async move {
409438 match client
410- . get_stream ( uri. clone ( ) , session_id. clone ( ) , None , auth_header. clone ( ) )
439+ . get_stream (
440+ uri. clone ( ) ,
441+ session_id. clone ( ) ,
442+ None ,
443+ auth_header. clone ( ) ,
444+ spawn_headers. clone ( ) ,
445+ )
411446 . await
412447 {
413448 Ok ( stream) => {
@@ -418,6 +453,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
418453 session_id : session_id. clone ( ) ,
419454 uri : config_uri,
420455 auth_header : config_auth_header,
456+ custom_headers : spawn_headers,
421457 } ,
422458 retry_config,
423459 ) ;
@@ -482,7 +518,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
482518 message,
483519 session_id. clone ( ) ,
484520 config. auth_header . clone ( ) ,
485- config . custom_headers . clone ( ) ,
521+ protocol_headers . clone ( ) ,
486522 )
487523 . await ;
488524 let send_result = match response {
@@ -504,6 +540,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
504540 session_id : session_id. clone ( ) ,
505541 uri : config. uri . clone ( ) ,
506542 auth_header : config. auth_header . clone ( ) ,
543+ custom_headers : protocol_headers. clone ( ) ,
507544 } ,
508545 self . config . retry_config . clone ( ) ,
509546 ) ;
@@ -550,32 +587,41 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
550587
551588 // Cleanup session before returning (ensures close() waits for session deletion)
552589 // Use a timeout to prevent indefinite hangs if the server is unresponsive
553- if let Some ( ( client , url , session_id , auth_header ) ) = session_cleanup_info {
590+ if let Some ( cleanup ) = session_cleanup_info {
554591 const SESSION_CLEANUP_TIMEOUT : std:: time:: Duration = std:: time:: Duration :: from_secs ( 5 ) ;
592+ let cleanup_session_id = cleanup. session_id . clone ( ) ;
555593 match tokio:: time:: timeout (
556594 SESSION_CLEANUP_TIMEOUT ,
557- client. delete_session ( url, session_id. clone ( ) , auth_header) ,
595+ cleanup. client . delete_session (
596+ cleanup. uri ,
597+ cleanup. session_id ,
598+ cleanup. auth_header ,
599+ cleanup. protocol_headers ,
600+ ) ,
558601 )
559602 . await
560603 {
561604 Ok ( Ok ( _) ) => {
562- tracing:: info!( session_id = session_id. as_ref( ) , "delete session success" )
605+ tracing:: info!(
606+ session_id = cleanup_session_id. as_ref( ) ,
607+ "delete session success"
608+ )
563609 }
564610 Ok ( Err ( StreamableHttpError :: ServerDoesNotSupportDeleteSession ) ) => {
565611 tracing:: info!(
566- session_id = session_id . as_ref( ) ,
612+ session_id = cleanup_session_id . as_ref( ) ,
567613 "server doesn't support delete session"
568614 )
569615 }
570616 Ok ( Err ( e) ) => {
571617 tracing:: error!(
572- session_id = session_id . as_ref( ) ,
618+ session_id = cleanup_session_id . as_ref( ) ,
573619 "fail to delete session: {e}"
574620 ) ;
575621 }
576622 Err ( _elapsed) => {
577623 tracing:: warn!(
578- session_id = session_id . as_ref( ) ,
624+ session_id = cleanup_session_id . as_ref( ) ,
579625 "session cleanup timed out after {:?}" ,
580626 SESSION_CLEANUP_TIMEOUT
581627 ) ;
@@ -652,6 +698,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
652698/// _uri: Arc<str>,
653699/// _session_id: Arc<str>,
654700/// _auth_header: Option<String>,
701+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
655702/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
656703/// todo!()
657704/// }
@@ -662,6 +709,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
662709/// _session_id: Arc<str>,
663710/// _last_event_id: Option<String>,
664711/// _auth_header: Option<String>,
712+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
665713/// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
666714/// todo!()
667715/// }
@@ -737,6 +785,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
737785 /// _uri: Arc<str>,
738786 /// _session_id: Arc<str>,
739787 /// _auth_header: Option<String>,
788+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
740789 /// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
741790 /// todo!()
742791 /// }
@@ -747,6 +796,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
747796 /// _session_id: Arc<str>,
748797 /// _last_event_id: Option<String>,
749798 /// _auth_header: Option<String>,
799+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
750800 /// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
751801 /// todo!()
752802 /// }
0 commit comments