Skip to content

Commit 5fa012d

Browse files
authored
feat: send and validate MCP-Protocol-Version header (#675)
1 parent 91e208e commit 5fa012d

6 files changed

Lines changed: 541 additions & 60 deletions

File tree

crates/rmcp/src/model.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,15 @@ impl ProtocolVersion {
155155
pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05"));
156156
// Keep LATEST at 2025-03-26 until full 2025-06-18 compliance and automated testing are in place.
157157
pub const LATEST: Self = Self::V_2025_03_26;
158+
159+
/// All protocol versions known to this SDK.
160+
pub const KNOWN_VERSIONS: &[Self] =
161+
&[Self::V_2024_11_05, Self::V_2025_03_26, Self::V_2025_06_18];
162+
163+
/// Returns the string representation of this protocol version.
164+
pub fn as_str(&self) -> &str {
165+
&self.0
166+
}
158167
}
159168

160169
impl Serialize for ProtocolVersion {

crates/rmcp/src/transport/common/auth/streamable_http_client.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ where
1717
uri: std::sync::Arc<str>,
1818
session_id: std::sync::Arc<str>,
1919
mut auth_token: Option<String>,
20+
custom_headers: HashMap<HeaderName, HeaderValue>,
2021
) -> Result<(), crate::transport::streamable_http_client::StreamableHttpError<Self::Error>>
2122
{
2223
if auth_token.is_none() {
2324
auth_token = Some(self.get_access_token().await?);
2425
}
2526
self.http_client
26-
.delete_session(uri, session_id, auth_token)
27+
.delete_session(uri, session_id, auth_token, custom_headers)
2728
.await
2829
}
2930

@@ -33,6 +34,7 @@ where
3334
session_id: std::sync::Arc<str>,
3435
last_event_id: Option<String>,
3536
mut auth_token: Option<String>,
37+
custom_headers: HashMap<HeaderName, HeaderValue>,
3638
) -> Result<
3739
futures::stream::BoxStream<'static, Result<sse_stream::Sse, sse_stream::Error>>,
3840
crate::transport::streamable_http_client::StreamableHttpError<Self::Error>,
@@ -41,7 +43,7 @@ where
4143
auth_token = Some(self.get_access_token().await?);
4244
}
4345
self.http_client
44-
.get_stream(uri, session_id, last_event_id, auth_token)
46+
.get_stream(uri, session_id, last_event_id, auth_token, custom_headers)
4547
.await
4648
}
4749

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,43 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
2222
}
2323
}
2424

25+
/// Reserved headers that must not be overridden by user-supplied custom headers.
26+
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
27+
/// injects it after initialization.
28+
const RESERVED_HEADERS: &[&str] = &[
29+
"accept",
30+
HEADER_SESSION_ID,
31+
HEADER_MCP_PROTOCOL_VERSION,
32+
HEADER_LAST_EVENT_ID,
33+
];
34+
35+
/// Applies custom headers to a request builder, rejecting reserved headers
36+
/// except `MCP-Protocol-Version` (which the worker injects after init).
37+
fn apply_custom_headers(
38+
mut builder: reqwest::RequestBuilder,
39+
custom_headers: HashMap<HeaderName, HeaderValue>,
40+
) -> Result<reqwest::RequestBuilder, StreamableHttpError<reqwest::Error>> {
41+
for (name, value) in custom_headers {
42+
if RESERVED_HEADERS
43+
.iter()
44+
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
45+
{
46+
if name
47+
.as_str()
48+
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
49+
{
50+
builder = builder.header(name, value);
51+
continue;
52+
}
53+
return Err(StreamableHttpError::ReservedHeaderConflict(
54+
name.to_string(),
55+
));
56+
}
57+
builder = builder.header(name, value);
58+
}
59+
Ok(builder)
60+
}
61+
2562
impl StreamableHttpClient for reqwest::Client {
2663
type Error = reqwest::Error;
2764

@@ -31,6 +68,7 @@ impl StreamableHttpClient for reqwest::Client {
3168
session_id: Arc<str>,
3269
last_event_id: Option<String>,
3370
auth_token: Option<String>,
71+
custom_headers: HashMap<HeaderName, HeaderValue>,
3472
) -> Result<BoxStream<'static, Result<Sse, SseError>>, StreamableHttpError<Self::Error>> {
3573
let mut request_builder = self
3674
.get(uri.as_ref())
@@ -42,6 +80,7 @@ impl StreamableHttpClient for reqwest::Client {
4280
if let Some(auth_header) = auth_token {
4381
request_builder = request_builder.bearer_auth(auth_header);
4482
}
83+
request_builder = apply_custom_headers(request_builder, custom_headers)?;
4584
let response = request_builder.send().await?;
4685
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
4786
return Err(StreamableHttpError::ServerDoesNotSupportSse);
@@ -70,15 +109,15 @@ impl StreamableHttpClient for reqwest::Client {
70109
uri: Arc<str>,
71110
session: Arc<str>,
72111
auth_token: Option<String>,
112+
custom_headers: HashMap<HeaderName, HeaderValue>,
73113
) -> Result<(), StreamableHttpError<Self::Error>> {
74114
let mut request_builder = self.delete(uri.as_ref());
75115
if let Some(auth_header) = auth_token {
76116
request_builder = request_builder.bearer_auth(auth_header);
77117
}
78-
let response = request_builder
79-
.header(HEADER_SESSION_ID, session.as_ref())
80-
.send()
81-
.await?;
118+
request_builder = request_builder.header(HEADER_SESSION_ID, session.as_ref());
119+
request_builder = apply_custom_headers(request_builder, custom_headers)?;
120+
let response = request_builder.send().await?;
82121

83122
// if method no allowed
84123
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
@@ -104,25 +143,7 @@ impl StreamableHttpClient for reqwest::Client {
104143
request = request.bearer_auth(auth_header);
105144
}
106145

107-
// Apply custom headers
108-
let reserved_headers = [
109-
ACCEPT.as_str(),
110-
HEADER_SESSION_ID,
111-
HEADER_MCP_PROTOCOL_VERSION,
112-
HEADER_LAST_EVENT_ID,
113-
];
114-
for (name, value) in custom_headers {
115-
if reserved_headers
116-
.iter()
117-
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
118-
{
119-
return Err(StreamableHttpError::ReservedHeaderConflict(
120-
name.to_string(),
121-
));
122-
}
123-
124-
request = request.header(name, value);
125-
}
146+
request = apply_custom_headers(request, custom_headers)?;
126147
if let Some(session_id) = session_id {
127148
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
128149
}

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tracing::debug;
1111
use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
1212
use crate::{
1313
RoleClient,
14-
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
14+
model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
1515
transport::{
1616
common::client_side_sse::SseAutoReconnectStream,
1717
worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
@@ -184,13 +184,15 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
184184
uri: Arc<str>,
185185
session_id: Arc<str>,
186186
auth_header: Option<String>,
187+
custom_headers: HashMap<HeaderName, HeaderValue>,
187188
) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
188189
fn get_stream(
189190
&self,
190191
uri: Arc<str>,
191192
session_id: Arc<str>,
192193
last_event_id: Option<String>,
193194
auth_header: Option<String>,
195+
custom_headers: HashMap<HeaderName, HeaderValue>,
194196
) -> impl Future<
195197
Output = Result<
196198
BoxStream<'static, Result<Sse, SseError>>,
@@ -210,6 +212,7 @@ struct StreamableHttpClientReconnect<C> {
210212
pub session_id: Arc<str>,
211213
pub uri: Arc<str>,
212214
pub auth_header: Option<String>,
215+
pub custom_headers: HashMap<HeaderName, HeaderValue>,
213216
}
214217

215218
impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
@@ -220,15 +223,25 @@ impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconne
220223
let uri = self.uri.clone();
221224
let session_id = self.session_id.clone();
222225
let auth_header = self.auth_header.clone();
226+
let custom_headers = self.custom_headers.clone();
223227
let last_event_id = last_event_id.map(|s| s.to_owned());
224228
Box::pin(async move {
225229
client
226-
.get_stream(uri, session_id, last_event_id, auth_header)
230+
.get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
227231
.await
228232
})
229233
}
230234
}
231235

236+
/// Info retained for cleaning up the session when the worker exits.
237+
struct SessionCleanupInfo<C> {
238+
client: C,
239+
uri: Arc<str>,
240+
session_id: Arc<str>,
241+
auth_header: Option<String>,
242+
protocol_headers: HashMap<HeaderName, HeaderValue>,
243+
}
244+
232245
#[derive(Debug, Clone, Default)]
233246
pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
234247
pub client: C,
@@ -357,14 +370,29 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
357370
}
358371
None
359372
};
373+
// Extract the negotiated protocol version from the init response
374+
// and build a custom headers map that includes MCP-Protocol-Version
375+
// for all subsequent HTTP requests (per MCP 2025-06-18 spec).
376+
let protocol_headers = {
377+
let mut headers = config.custom_headers.clone();
378+
if let ServerJsonRpcMessage::Response(response) = &message {
379+
if let ServerResult::InitializeResult(init_result) = &response.result {
380+
if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
381+
// HeaderName::from_static requires lowercase
382+
headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
383+
}
384+
}
385+
}
386+
headers
387+
};
388+
360389
// Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
361-
let session_cleanup_info = session_id.as_ref().map(|sid| {
362-
(
363-
self.client.clone(),
364-
config.uri.clone(),
365-
sid.clone(),
366-
config.auth_header.clone(),
367-
)
390+
let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
391+
client: self.client.clone(),
392+
uri: config.uri.clone(),
393+
session_id: sid.clone(),
394+
auth_header: config.auth_header.clone(),
395+
protocol_headers: protocol_headers.clone(),
368396
});
369397

370398
context.send_to_handler(message).await?;
@@ -376,7 +404,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
376404
initialized_notification.message,
377405
session_id.clone(),
378406
config.auth_header.clone(),
379-
config.custom_headers.clone(),
407+
protocol_headers.clone(),
380408
)
381409
.await
382410
.map_err(WorkerQuitReason::fatal_context(
@@ -404,10 +432,17 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
404432
let transport_task_ct = transport_task_ct.clone();
405433
let config_uri = config.uri.clone();
406434
let config_auth_header = config.auth_header.clone();
435+
let spawn_headers = protocol_headers.clone();
407436

408437
streams.spawn(async move {
409438
match client
410-
.get_stream(uri.clone(), session_id.clone(), None, auth_header.clone())
439+
.get_stream(
440+
uri.clone(),
441+
session_id.clone(),
442+
None,
443+
auth_header.clone(),
444+
spawn_headers.clone(),
445+
)
411446
.await
412447
{
413448
Ok(stream) => {
@@ -418,6 +453,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
418453
session_id: session_id.clone(),
419454
uri: config_uri,
420455
auth_header: config_auth_header,
456+
custom_headers: spawn_headers,
421457
},
422458
retry_config,
423459
);
@@ -482,7 +518,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
482518
message,
483519
session_id.clone(),
484520
config.auth_header.clone(),
485-
config.custom_headers.clone(),
521+
protocol_headers.clone(),
486522
)
487523
.await;
488524
let send_result = match response {
@@ -504,6 +540,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
504540
session_id: session_id.clone(),
505541
uri: config.uri.clone(),
506542
auth_header: config.auth_header.clone(),
543+
custom_headers: protocol_headers.clone(),
507544
},
508545
self.config.retry_config.clone(),
509546
);
@@ -550,32 +587,41 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
550587

551588
// Cleanup session before returning (ensures close() waits for session deletion)
552589
// Use a timeout to prevent indefinite hangs if the server is unresponsive
553-
if let Some((client, url, session_id, auth_header)) = session_cleanup_info {
590+
if let Some(cleanup) = session_cleanup_info {
554591
const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
592+
let cleanup_session_id = cleanup.session_id.clone();
555593
match tokio::time::timeout(
556594
SESSION_CLEANUP_TIMEOUT,
557-
client.delete_session(url, session_id.clone(), auth_header),
595+
cleanup.client.delete_session(
596+
cleanup.uri,
597+
cleanup.session_id,
598+
cleanup.auth_header,
599+
cleanup.protocol_headers,
600+
),
558601
)
559602
.await
560603
{
561604
Ok(Ok(_)) => {
562-
tracing::info!(session_id = session_id.as_ref(), "delete session success")
605+
tracing::info!(
606+
session_id = cleanup_session_id.as_ref(),
607+
"delete session success"
608+
)
563609
}
564610
Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
565611
tracing::info!(
566-
session_id = session_id.as_ref(),
612+
session_id = cleanup_session_id.as_ref(),
567613
"server doesn't support delete session"
568614
)
569615
}
570616
Ok(Err(e)) => {
571617
tracing::error!(
572-
session_id = session_id.as_ref(),
618+
session_id = cleanup_session_id.as_ref(),
573619
"fail to delete session: {e}"
574620
);
575621
}
576622
Err(_elapsed) => {
577623
tracing::warn!(
578-
session_id = session_id.as_ref(),
624+
session_id = cleanup_session_id.as_ref(),
579625
"session cleanup timed out after {:?}",
580626
SESSION_CLEANUP_TIMEOUT
581627
);
@@ -652,6 +698,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
652698
/// _uri: Arc<str>,
653699
/// _session_id: Arc<str>,
654700
/// _auth_header: Option<String>,
701+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
655702
/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
656703
/// todo!()
657704
/// }
@@ -662,6 +709,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
662709
/// _session_id: Arc<str>,
663710
/// _last_event_id: Option<String>,
664711
/// _auth_header: Option<String>,
712+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
665713
/// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
666714
/// todo!()
667715
/// }
@@ -737,6 +785,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
737785
/// _uri: Arc<str>,
738786
/// _session_id: Arc<str>,
739787
/// _auth_header: Option<String>,
788+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
740789
/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
741790
/// todo!()
742791
/// }
@@ -747,6 +796,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
747796
/// _session_id: Arc<str>,
748797
/// _last_event_id: Option<String>,
749798
/// _auth_header: Option<String>,
799+
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
750800
/// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
751801
/// todo!()
752802
/// }

0 commit comments

Comments
 (0)