Skip to content

Commit c330fed

Browse files
authored
fix: reject init header/body version mismatch (#853)
1 parent d328751 commit c330fed

3 files changed

Lines changed: 233 additions & 12 deletions

File tree

crates/rmcp/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ name = "test_streamable_http_json_response"
272272
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
273273
path = "tests/test_streamable_http_json_response.rs"
274274

275+
[[test]]
276+
name = "test_streamable_http_protocol_version"
277+
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
278+
path = "tests/test_streamable_http_protocol_version.rs"
279+
275280
[[test]]
276281
name = "test_streamable_http_4xx_error_body"
277282
required-features = ["transport-streamable-http-client", "transport-streamable-http-client-reqwest"]

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};
1+
use std::{
2+
borrow::Cow, collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration,
3+
};
24

35
use bytes::Bytes;
46
use futures::{StreamExt, future::BoxFuture};
@@ -14,8 +16,8 @@ use super::session::{
1416
use crate::{
1517
RoleServer,
1618
model::{
17-
ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
18-
InitializedNotification, ProtocolVersion,
19+
ClientJsonRpcMessage, ClientNotification, ClientRequest, ErrorData, GetExtensions,
20+
InitializeRequest, InitializedNotification, JsonRpcError, ProtocolVersion, RequestId,
1921
},
2022
serve_server,
2123
service::serve_directly,
@@ -209,6 +211,54 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
209211
Ok(())
210212
}
211213

214+
fn invalid_request_jsonrpc_response(
215+
id: Option<RequestId>,
216+
message: impl Into<Cow<'static, str>>,
217+
) -> BoxResponse {
218+
let err = JsonRpcError::new(id, ErrorData::invalid_request(message, None));
219+
let body = serde_json::to_vec(&err).expect("serialize JsonRpcError");
220+
Response::builder()
221+
.status(http::StatusCode::BAD_REQUEST)
222+
.header(http::header::CONTENT_TYPE, JSON_MIME_TYPE)
223+
.body(Full::new(Bytes::from(body)).boxed())
224+
.expect("valid response")
225+
}
226+
227+
#[expect(
228+
clippy::result_large_err,
229+
reason = "BoxResponse is intentionally large; matches other handlers in this file"
230+
)]
231+
/// Absent header is allowed; the first initialize round-trip may legitimately omit it.
232+
fn validate_header_matches_init_body(
233+
headers: &http::HeaderMap,
234+
body_version: &str,
235+
request_id: Option<RequestId>,
236+
) -> Result<(), BoxResponse> {
237+
let Some(header_value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) else {
238+
return Ok(());
239+
};
240+
let header_str = header_value.to_str().map_err(|_| {
241+
invalid_request_jsonrpc_response(
242+
request_id.clone(),
243+
"Invalid Request: MCP-Protocol-Version header is not valid UTF-8",
244+
)
245+
})?;
246+
if header_str != body_version {
247+
tracing::warn!(
248+
header = header_str,
249+
body = body_version,
250+
"rejecting initialize: MCP-Protocol-Version header does not match params.protocolVersion"
251+
);
252+
return Err(invalid_request_jsonrpc_response(
253+
request_id,
254+
format!(
255+
"Invalid Request: MCP-Protocol-Version header ({header_str}) does not match initialize params.protocolVersion ({body_version})"
256+
),
257+
));
258+
}
259+
Ok(())
260+
}
261+
212262
fn forbidden_response(message: impl Into<String>) -> BoxResponse {
213263
Response::builder()
214264
.status(http::StatusCode::FORBIDDEN)
@@ -1095,9 +1145,15 @@ where
10951145
None
10961146
};
10971147
if let ClientJsonRpcMessage::Request(req) = &mut message {
1098-
if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
1148+
let ClientRequest::InitializeRequest(init_req) = &req.request else {
10991149
return Err(unexpected_message_response("initialize request"));
1100-
}
1150+
};
1151+
// Reject mismatched MCP-Protocol-Version header before binding the session to anything.
1152+
validate_header_matches_init_body(
1153+
&part.headers,
1154+
init_req.params.protocol_version.as_str(),
1155+
Some(req.id.clone()),
1156+
)?;
11011157
// inject request part to extensions
11021158
req.request.extensions_mut().insert(part);
11031159
} else {
@@ -1163,13 +1219,24 @@ where
11631219
Ok(response)
11641220
}
11651221
} else {
1166-
// Stateless mode: validate MCP-Protocol-Version on non-init requests
1167-
let is_init = matches!(
1168-
&message,
1169-
ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_))
1170-
);
1171-
if !is_init {
1172-
validate_protocol_version_header(&part.headers)?;
1222+
// Stateless mode:
1223+
// - on initialize: the header (if present) must match `params.protocolVersion`
1224+
// - on every other request: the header must name a known version.
1225+
match &message {
1226+
ClientJsonRpcMessage::Request(req) => {
1227+
if let ClientRequest::InitializeRequest(init_req) = &req.request {
1228+
validate_header_matches_init_body(
1229+
&part.headers,
1230+
init_req.params.protocol_version.as_str(),
1231+
Some(req.id.clone()),
1232+
)?;
1233+
} else {
1234+
validate_protocol_version_header(&part.headers)?;
1235+
}
1236+
}
1237+
_ => {
1238+
validate_protocol_version_header(&part.headers)?;
1239+
}
11731240
}
11741241
let service = self
11751242
.get_service()
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#![cfg(not(feature = "local"))]
2+
//! Regression tests for the `MCP-Protocol-Version` header / initialize body consistency check.
3+
use rmcp::transport::streamable_http_server::{
4+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
5+
};
6+
use tokio_util::sync::CancellationToken;
7+
8+
mod common;
9+
use common::calculator::Calculator;
10+
11+
fn init_body(body_version: &str) -> String {
12+
format!(
13+
r#"{{"jsonrpc":"2.0","id":1,"method":"initialize","params":{{"protocolVersion":"{body_version}","capabilities":{{}},"clientInfo":{{"name":"test","version":"1.0"}}}}}}"#
14+
)
15+
}
16+
17+
async fn spawn_server(
18+
config: StreamableHttpServerConfig,
19+
) -> (reqwest::Client, String, CancellationToken) {
20+
let ct = config.cancellation_token.clone();
21+
let service: StreamableHttpService<Calculator, LocalSessionManager> =
22+
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), config);
23+
24+
let router = axum::Router::new().nest_service("/mcp", service);
25+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
26+
let addr = tcp_listener.local_addr().unwrap();
27+
28+
tokio::spawn({
29+
let ct = ct.clone();
30+
async move {
31+
let _ = axum::serve(tcp_listener, router)
32+
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
33+
.await;
34+
}
35+
});
36+
37+
let client = reqwest::Client::new();
38+
let base_url = format!("http://{addr}/mcp");
39+
(client, base_url, ct)
40+
}
41+
42+
fn stateless_json_config() -> StreamableHttpServerConfig {
43+
StreamableHttpServerConfig::default()
44+
.with_stateful_mode(false)
45+
.with_json_response(true)
46+
.with_sse_keep_alive(None)
47+
.with_cancellation_token(CancellationToken::new())
48+
}
49+
50+
fn stateful_config() -> StreamableHttpServerConfig {
51+
StreamableHttpServerConfig::default()
52+
.with_stateful_mode(true)
53+
.with_sse_keep_alive(None)
54+
.with_cancellation_token(CancellationToken::new())
55+
}
56+
57+
async fn post_init(
58+
client: &reqwest::Client,
59+
url: &str,
60+
header: Option<&str>,
61+
body_version: &str,
62+
) -> reqwest::Response {
63+
let mut req = client
64+
.post(url)
65+
.header("Content-Type", "application/json")
66+
.header("Accept", "application/json, text/event-stream")
67+
.body(init_body(body_version));
68+
if let Some(h) = header {
69+
req = req.header("MCP-Protocol-Version", h);
70+
}
71+
req.send().await.expect("send initialize request")
72+
}
73+
74+
#[tokio::test]
75+
async fn stateless_init_rejects_when_header_older_than_body() -> anyhow::Result<()> {
76+
let (client, url, ct) = spawn_server(stateless_json_config()).await;
77+
78+
let response = post_init(&client, &url, Some("2025-03-26"), "2025-11-25").await;
79+
assert_eq!(response.status(), 400);
80+
81+
let body: serde_json::Value = response.json().await?;
82+
assert_eq!(body["error"]["code"], -32600);
83+
assert!(
84+
body["error"]["message"]
85+
.as_str()
86+
.unwrap_or_default()
87+
.contains("MCP-Protocol-Version"),
88+
"expected error message to mention the header, got: {body}"
89+
);
90+
91+
ct.cancel();
92+
Ok(())
93+
}
94+
95+
#[tokio::test]
96+
async fn stateless_init_rejects_when_header_newer_than_body() -> anyhow::Result<()> {
97+
let (client, url, ct) = spawn_server(stateless_json_config()).await;
98+
99+
let response = post_init(&client, &url, Some("2025-11-25"), "2025-03-26").await;
100+
assert_eq!(response.status(), 400);
101+
102+
let body: serde_json::Value = response.json().await?;
103+
assert_eq!(body["error"]["code"], -32600);
104+
105+
ct.cancel();
106+
Ok(())
107+
}
108+
109+
#[tokio::test]
110+
async fn stateless_init_accepts_when_header_matches_body() -> anyhow::Result<()> {
111+
let (client, url, ct) = spawn_server(stateless_json_config()).await;
112+
113+
let response = post_init(&client, &url, Some("2025-11-25"), "2025-11-25").await;
114+
assert_eq!(response.status(), 200);
115+
116+
let body: serde_json::Value = response.json().await?;
117+
assert!(
118+
body["result"].is_object(),
119+
"expected an InitializeResult, got: {body}"
120+
);
121+
122+
ct.cancel();
123+
Ok(())
124+
}
125+
126+
#[tokio::test]
127+
async fn stateless_init_accepts_when_header_absent() -> anyhow::Result<()> {
128+
let (client, url, ct) = spawn_server(stateless_json_config()).await;
129+
130+
let response = post_init(&client, &url, None, "2025-11-25").await;
131+
assert_eq!(response.status(), 200);
132+
133+
ct.cancel();
134+
Ok(())
135+
}
136+
137+
#[tokio::test]
138+
async fn stateful_init_rejects_when_header_mismatches_body() -> anyhow::Result<()> {
139+
let (client, url, ct) = spawn_server(stateful_config()).await;
140+
141+
let response = post_init(&client, &url, Some("2024-11-05"), "2025-11-25").await;
142+
assert_eq!(response.status(), 400);
143+
144+
let body: serde_json::Value = response.json().await?;
145+
assert_eq!(body["error"]["code"], -32600);
146+
147+
ct.cancel();
148+
Ok(())
149+
}

0 commit comments

Comments
 (0)