Skip to content

Commit 016b7d3

Browse files
authored
feat: add support for custom HTTP headers in StreamableHttpClient (#655)
* feat: add support for custom HTTP headers in StreamableHttpClient * feat: implement reserved header checks for custom HTTP headers in StreamableHttpClient
1 parent 70f6380 commit 016b7d3

7 files changed

Lines changed: 622 additions & 8 deletions

File tree

crates/rmcp/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,13 @@ path = "tests/test_sampling.rs"
230230
name = "test_close_connection"
231231
required-features = ["server", "client"]
232232
path = "tests/test_close_connection.rs"
233+
234+
[[test]]
235+
name = "test_custom_headers"
236+
required-features = [
237+
"client",
238+
"server",
239+
"transport-streamable-http-client-reqwest",
240+
"transport-streamable-http-server",
241+
]
242+
path = "tests/test_custom_headers.rs"

crates/rmcp/src/transport/auth.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use thiserror::Error;
1616
use tokio::sync::{Mutex, RwLock};
1717
use tracing::{debug, error, warn};
1818

19+
use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION;
20+
1921
const DEFAULT_EXCHANGE_URL: &str = "http://localhost";
2022

2123
/// Stored credentials for OAuth2 authorization
@@ -1068,7 +1070,7 @@ impl AuthorizationManager {
10681070
let response = match self
10691071
.http_client
10701072
.get(discovery_url.clone())
1071-
.header("MCP-Protocol-Version", "2024-11-05")
1073+
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
10721074
.send()
10731075
.await
10741076
{
@@ -1188,7 +1190,7 @@ impl AuthorizationManager {
11881190
let response = match self
11891191
.http_client
11901192
.get(url.clone())
1191-
.header("MCP-Protocol-Version", "2024-11-05")
1193+
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
11921194
.send()
11931195
.await
11941196
{
@@ -1241,7 +1243,7 @@ impl AuthorizationManager {
12411243
let response = match self
12421244
.http_client
12431245
.get(resource_metadata_url.clone())
1244-
.header("MCP-Protocol-Version", "2024-11-05")
1246+
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
12451247
.send()
12461248
.await
12471249
{

crates/rmcp/src/transport/common/auth/streamable_http_client.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
use std::collections::HashMap;
2+
3+
use http::{HeaderName, HeaderValue};
4+
15
use crate::transport::{
26
auth::AuthClient,
37
streamable_http_client::{StreamableHttpClient, StreamableHttpError},
@@ -47,6 +51,7 @@ where
4751
message: crate::model::ClientJsonRpcMessage,
4852
session_id: Option<std::sync::Arc<str>>,
4953
mut auth_token: Option<String>,
54+
custom_headers: HashMap<HeaderName, HeaderValue>,
5055
) -> Result<
5156
crate::transport::streamable_http_client::StreamableHttpPostResponse,
5257
StreamableHttpError<Self::Error>,
@@ -55,7 +60,7 @@ where
5560
auth_token = Some(self.get_access_token().await?);
5661
}
5762
self.http_client
58-
.post_message(uri, message, session_id, auth_token)
63+
.post_message(uri, message, session_id, auth_token, custom_headers)
5964
.await
6065
}
6166
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
22
pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
3+
pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version";
34
pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
45
pub const JSON_MIME_TYPE: &str = "application/json";

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
use std::{borrow::Cow, sync::Arc};
1+
use std::{borrow::Cow, collections::HashMap, sync::Arc};
22

33
use futures::{StreamExt, stream::BoxStream};
4-
use http::header::WWW_AUTHENTICATE;
4+
use http::{HeaderName, HeaderValue, header::WWW_AUTHENTICATE};
55
use reqwest::header::ACCEPT;
66
use sse_stream::{Sse, SseStream};
77

88
use crate::{
99
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
1010
transport::{
1111
common::http_header::{
12-
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
12+
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
13+
HEADER_SESSION_ID, JSON_MIME_TYPE,
1314
},
1415
streamable_http_client::*,
1516
},
@@ -94,13 +95,34 @@ impl StreamableHttpClient for reqwest::Client {
9495
message: ClientJsonRpcMessage,
9596
session_id: Option<Arc<str>>,
9697
auth_token: Option<String>,
98+
custom_headers: HashMap<HeaderName, HeaderValue>,
9799
) -> Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
98100
let mut request = self
99101
.post(uri.as_ref())
100102
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
101103
if let Some(auth_header) = auth_token {
102104
request = request.bearer_auth(auth_header);
103105
}
106+
107+
// Apply custom headers
108+
let reserved_headers = [
109+
ACCEPT.as_str(),
110+
HEADER_SESSION_ID,
111+
HEADER_MCP_PROTOCOL_VERSION,
112+
HEADER_LAST_EVENT_ID,
113+
];
114+
for (name, value) in custom_headers {
115+
if reserved_headers
116+
.iter()
117+
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
118+
{
119+
return Err(StreamableHttpError::ReservedHeaderConflict(
120+
name.to_string(),
121+
));
122+
}
123+
124+
request = request.header(name, value);
125+
}
104126
if let Some(session_id) = session_id {
105127
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
106128
}

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use std::{borrow::Cow, sync::Arc, time::Duration};
1+
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
22

33
use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
4+
use http::{HeaderName, HeaderValue};
45
pub use sse_stream::Error as SseError;
56
use sse_stream::Sse;
67
use thiserror::Error;
@@ -76,6 +77,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
7677
AuthRequired(AuthRequiredError),
7778
#[error("Insufficient scope")]
7879
InsufficientScope(InsufficientScopeError),
80+
#[error("Header name '{0}' is reserved and conflicts with default headers")]
81+
ReservedHeaderConflict(String),
7982
}
8083

8184
#[derive(Debug, Clone, Error)]
@@ -173,6 +176,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
173176
message: ClientJsonRpcMessage,
174177
session_id: Option<Arc<str>>,
175178
auth_header: Option<String>,
179+
custom_headers: HashMap<HeaderName, HeaderValue>,
176180
) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
177181
+ Send
178182
+ '_;
@@ -324,6 +328,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
324328
initialize_request,
325329
None,
326330
self.config.auth_header,
331+
self.config.custom_headers,
327332
)
328333
.await
329334
{
@@ -372,6 +377,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
372377
initialized_notification.message,
373378
session_id.clone(),
374379
config.auth_header.clone(),
380+
config.custom_headers.clone(),
375381
)
376382
.await
377383
.map_err(WorkerQuitReason::fatal_context(
@@ -477,6 +483,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
477483
message,
478484
session_id.clone(),
479485
config.auth_header.clone(),
486+
config.custom_headers.clone(),
480487
)
481488
.await;
482489
let send_result = match response {
@@ -609,8 +616,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
609616
/// StreamableHttpClientTransportConfig
610617
/// };
611618
/// use std::sync::Arc;
619+
/// use std::collections::HashMap;
612620
/// use futures::stream::BoxStream;
613621
/// use rmcp::model::ClientJsonRpcMessage;
622+
/// use http::{HeaderName, HeaderValue};
614623
/// use sse_stream::{Sse, Error as SseError};
615624
///
616625
/// #[derive(Clone)]
@@ -634,6 +643,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
634643
/// _message: ClientJsonRpcMessage,
635644
/// _session_id: Option<Arc<str>>,
636645
/// _auth_header: Option<String>,
646+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
637647
/// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
638648
/// todo!()
639649
/// }
@@ -690,8 +700,10 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
690700
/// StreamableHttpClientTransportConfig
691701
/// };
692702
/// use std::sync::Arc;
703+
/// use std::collections::HashMap;
693704
/// use futures::stream::BoxStream;
694705
/// use rmcp::model::ClientJsonRpcMessage;
706+
/// use http::{HeaderName, HeaderValue};
695707
/// use sse_stream::{Sse, Error as SseError};
696708
///
697709
/// // Define your custom client
@@ -716,6 +728,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
716728
/// _message: ClientJsonRpcMessage,
717729
/// _session_id: Option<Arc<str>>,
718730
/// _auth_header: Option<String>,
731+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
719732
/// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
720733
/// todo!()
721734
/// }
@@ -759,6 +772,8 @@ pub struct StreamableHttpClientTransportConfig {
759772
pub allow_stateless: bool,
760773
/// The value to send in the authorization header
761774
pub auth_header: Option<String>,
775+
/// Custom HTTP headers to include with every request
776+
pub custom_headers: HashMap<HeaderName, HeaderValue>,
762777
}
763778

764779
impl StreamableHttpClientTransportConfig {
@@ -779,6 +794,33 @@ impl StreamableHttpClientTransportConfig {
779794
self.auth_header = Some(value.into());
780795
self
781796
}
797+
798+
/// Set custom HTTP headers to include with every request
799+
///
800+
/// # Arguments
801+
///
802+
/// * `custom_headers` - A HashMap of header names to header values
803+
///
804+
/// # Example
805+
///
806+
/// ```rust,no_run
807+
/// use std::collections::HashMap;
808+
/// use http::{HeaderName, HeaderValue};
809+
/// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
810+
///
811+
/// let mut headers = HashMap::new();
812+
/// headers.insert(
813+
/// HeaderName::from_static("x-custom-header"),
814+
/// HeaderValue::from_static("custom-value")
815+
/// );
816+
///
817+
/// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
818+
/// .custom_headers(headers);
819+
/// ```
820+
pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
821+
self.custom_headers = custom_headers;
822+
self
823+
}
782824
}
783825

784826
impl Default for StreamableHttpClientTransportConfig {
@@ -789,6 +831,7 @@ impl Default for StreamableHttpClientTransportConfig {
789831
channel_buffer_capacity: 16,
790832
allow_stateless: true,
791833
auth_header: None,
834+
custom_headers: HashMap::new(),
792835
}
793836
}
794837
}

0 commit comments

Comments
 (0)