1- use std:: { borrow:: Cow , sync:: Arc , time:: Duration } ;
1+ use std:: { borrow:: Cow , collections :: HashMap , sync:: Arc , time:: Duration } ;
22
33use futures:: { Stream , StreamExt , future:: BoxFuture , stream:: BoxStream } ;
4+ use http:: { HeaderName , HeaderValue } ;
45pub use sse_stream:: Error as SseError ;
56use sse_stream:: Sse ;
67use thiserror:: Error ;
@@ -76,6 +77,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
7677 AuthRequired ( AuthRequiredError ) ,
7778 #[ error( "Insufficient scope" ) ]
7879 InsufficientScope ( InsufficientScopeError ) ,
80+ #[ error( "Header name '{0}' is reserved and conflicts with default headers" ) ]
81+ ReservedHeaderConflict ( String ) ,
7982}
8083
8184#[ derive( Debug , Clone , Error ) ]
@@ -173,6 +176,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
173176 message : ClientJsonRpcMessage ,
174177 session_id : Option < Arc < str > > ,
175178 auth_header : Option < String > ,
179+ custom_headers : HashMap < HeaderName , HeaderValue > ,
176180 ) -> impl Future < Output = Result < StreamableHttpPostResponse , StreamableHttpError < Self :: Error > > >
177181 + Send
178182 + ' _ ;
@@ -324,6 +328,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
324328 initialize_request,
325329 None ,
326330 self . config . auth_header ,
331+ self . config . custom_headers ,
327332 )
328333 . await
329334 {
@@ -372,6 +377,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
372377 initialized_notification. message ,
373378 session_id. clone ( ) ,
374379 config. auth_header . clone ( ) ,
380+ config. custom_headers . clone ( ) ,
375381 )
376382 . await
377383 . map_err ( WorkerQuitReason :: fatal_context (
@@ -477,6 +483,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
477483 message,
478484 session_id. clone ( ) ,
479485 config. auth_header . clone ( ) ,
486+ config. custom_headers . clone ( ) ,
480487 )
481488 . await ;
482489 let send_result = match response {
@@ -609,8 +616,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
609616/// StreamableHttpClientTransportConfig
610617/// };
611618/// use std::sync::Arc;
619+ /// use std::collections::HashMap;
612620/// use futures::stream::BoxStream;
613621/// use rmcp::model::ClientJsonRpcMessage;
622+ /// use http::{HeaderName, HeaderValue};
614623/// use sse_stream::{Sse, Error as SseError};
615624///
616625/// #[derive(Clone)]
@@ -634,6 +643,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
634643/// _message: ClientJsonRpcMessage,
635644/// _session_id: Option<Arc<str>>,
636645/// _auth_header: Option<String>,
646+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
637647/// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
638648/// todo!()
639649/// }
@@ -690,8 +700,10 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
690700 /// StreamableHttpClientTransportConfig
691701 /// };
692702 /// use std::sync::Arc;
703+ /// use std::collections::HashMap;
693704 /// use futures::stream::BoxStream;
694705 /// use rmcp::model::ClientJsonRpcMessage;
706+ /// use http::{HeaderName, HeaderValue};
695707 /// use sse_stream::{Sse, Error as SseError};
696708 ///
697709 /// // Define your custom client
@@ -716,6 +728,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
716728 /// _message: ClientJsonRpcMessage,
717729 /// _session_id: Option<Arc<str>>,
718730 /// _auth_header: Option<String>,
731+ /// _custom_headers: HashMap<HeaderName, HeaderValue>,
719732 /// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
720733 /// todo!()
721734 /// }
@@ -759,6 +772,8 @@ pub struct StreamableHttpClientTransportConfig {
759772 pub allow_stateless : bool ,
760773 /// The value to send in the authorization header
761774 pub auth_header : Option < String > ,
775+ /// Custom HTTP headers to include with every request
776+ pub custom_headers : HashMap < HeaderName , HeaderValue > ,
762777}
763778
764779impl StreamableHttpClientTransportConfig {
@@ -779,6 +794,33 @@ impl StreamableHttpClientTransportConfig {
779794 self . auth_header = Some ( value. into ( ) ) ;
780795 self
781796 }
797+
798+ /// Set custom HTTP headers to include with every request
799+ ///
800+ /// # Arguments
801+ ///
802+ /// * `custom_headers` - A HashMap of header names to header values
803+ ///
804+ /// # Example
805+ ///
806+ /// ```rust,no_run
807+ /// use std::collections::HashMap;
808+ /// use http::{HeaderName, HeaderValue};
809+ /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
810+ ///
811+ /// let mut headers = HashMap::new();
812+ /// headers.insert(
813+ /// HeaderName::from_static("x-custom-header"),
814+ /// HeaderValue::from_static("custom-value")
815+ /// );
816+ ///
817+ /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
818+ /// .custom_headers(headers);
819+ /// ```
820+ pub fn custom_headers ( mut self , custom_headers : HashMap < HeaderName , HeaderValue > ) -> Self {
821+ self . custom_headers = custom_headers;
822+ self
823+ }
782824}
783825
784826impl Default for StreamableHttpClientTransportConfig {
@@ -789,6 +831,7 @@ impl Default for StreamableHttpClientTransportConfig {
789831 channel_buffer_capacity : 16 ,
790832 allow_stateless : true ,
791833 auth_header : None ,
834+ custom_headers : HashMap :: new ( ) ,
792835 }
793836 }
794837}
0 commit comments