Skip to content

Commit 8e22aa2

Browse files
jokemanfireCopilot
andauthored
fix(http): add host check (#764)
Signed-off-by: jokemanfire <hu.dingyang@zte.com.cn> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 012210b commit 8e22aa2

File tree

3 files changed

+279
-2
lines changed

3 files changed

+279
-2
lines changed

crates/rmcp/src/transport/common/http_header.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ pub const JSON_MIME_TYPE: &str = "application/json";
77
/// Reserved headers that must not be overridden by user-supplied custom headers.
88
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
99
/// injects it after initialization.
10+
#[allow(dead_code)]
1011
pub(crate) const RESERVED_HEADERS: &[&str] = &[
1112
"accept",
1213
HEADER_SESSION_ID,
@@ -36,6 +37,7 @@ pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), Stri
3637

3738
/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value.
3839
/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
40+
#[cfg(feature = "client-side-sse")]
3941
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {
4042
let header_lowercase = header.to_ascii_lowercase();
4143
let scope_key = "scope=";

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
22

33
use bytes::Bytes;
44
use futures::{StreamExt, future::BoxFuture};
5-
use http::{Method, Request, Response, header::ALLOW};
5+
use http::{HeaderMap, Method, Request, Response, header::ALLOW};
66
use http_body::Body;
77
use http_body_util::{BodyExt, Full, combinators::BoxBody};
88
use tokio_stream::wrappers::ReceiverStream;
@@ -29,8 +29,8 @@ use crate::{
2929
},
3030
};
3131

32-
#[derive(Debug, Clone)]
3332
#[non_exhaustive]
33+
#[derive(Debug, Clone)]
3434
pub struct StreamableHttpServerConfig {
3535
/// The ping message duration for SSE connections.
3636
pub sse_keep_alive: Option<Duration>,
@@ -49,6 +49,16 @@ pub struct StreamableHttpServerConfig {
4949
/// When this token is cancelled, all active sessions are terminated and
5050
/// the server stops accepting new requests.
5151
pub cancellation_token: CancellationToken,
52+
/// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
53+
///
54+
/// By default, Streamable HTTP servers only accept loopback hosts to
55+
/// prevent DNS rebinding attacks against locally running servers. Public
56+
/// deployments should override this list with their own hostnames.
57+
/// examples:
58+
/// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
59+
/// or with ports:
60+
/// allowed_hosts = ["example.com", "example.com:8080"]
61+
pub allowed_hosts: Vec<String>,
5262
}
5363

5464
impl Default for StreamableHttpServerConfig {
@@ -59,11 +69,24 @@ impl Default for StreamableHttpServerConfig {
5969
stateful_mode: true,
6070
json_response: false,
6171
cancellation_token: CancellationToken::new(),
72+
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
6273
}
6374
}
6475
}
6576

6677
impl StreamableHttpServerConfig {
78+
pub fn with_allowed_hosts(
79+
mut self,
80+
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
81+
) -> Self {
82+
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
83+
self
84+
}
85+
/// Disable allowed hosts. This will allow requests with any `Host` header, which is NOT recommended for public deployments.
86+
pub fn disable_allowed_hosts(mut self) -> Self {
87+
self.allowed_hosts.clear();
88+
self
89+
}
6790
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
6891
self.sse_keep_alive = duration;
6992
self
@@ -130,6 +153,97 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
130153
Ok(())
131154
}
132155

156+
fn forbidden_response(message: impl Into<String>) -> BoxResponse {
157+
Response::builder()
158+
.status(http::StatusCode::FORBIDDEN)
159+
.body(Full::new(Bytes::from(message.into())).boxed())
160+
.expect("valid response")
161+
}
162+
163+
fn normalize_host(host: &str) -> String {
164+
host.trim_matches('[')
165+
.trim_matches(']')
166+
.to_ascii_lowercase()
167+
}
168+
169+
#[derive(Debug, Clone, PartialEq, Eq)]
170+
struct NormalizedAuthority {
171+
host: String,
172+
port: Option<u16>,
173+
}
174+
175+
fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
176+
NormalizedAuthority {
177+
host: normalize_host(host),
178+
port,
179+
}
180+
}
181+
182+
fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
183+
let allowed = allowed.trim();
184+
if allowed.is_empty() {
185+
return None;
186+
}
187+
188+
if let Ok(authority) = http::uri::Authority::try_from(allowed) {
189+
return Some(normalize_authority(authority.host(), authority.port_u16()));
190+
}
191+
192+
Some(normalize_authority(allowed, None))
193+
}
194+
195+
fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
196+
if allowed_hosts.is_empty() {
197+
// If the allowed hosts list is empty, allow all hosts (not recommended).
198+
return true;
199+
}
200+
allowed_hosts
201+
.iter()
202+
.filter_map(|allowed| parse_allowed_authority(allowed))
203+
.any(|allowed| {
204+
allowed.host == host.host
205+
&& match allowed.port {
206+
Some(port) => host.port == Some(port),
207+
None => true,
208+
}
209+
})
210+
}
211+
212+
fn bad_request_response(message: &str) -> BoxResponse {
213+
let body = Full::from(message.to_string()).boxed();
214+
215+
http::Response::builder()
216+
.status(http::StatusCode::BAD_REQUEST)
217+
.header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
218+
.body(body)
219+
.expect("failed to build bad request response")
220+
}
221+
222+
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
223+
let Some(host) = headers.get(http::header::HOST) else {
224+
return Err(bad_request_response("Bad Request: missing Host header"));
225+
};
226+
227+
let host = host
228+
.to_str()
229+
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
230+
let authority = http::uri::Authority::try_from(host)
231+
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
232+
Ok(normalize_authority(authority.host(), authority.port_u16()))
233+
}
234+
235+
fn validate_dns_rebinding_headers(
236+
headers: &HeaderMap,
237+
config: &StreamableHttpServerConfig,
238+
) -> Result<(), BoxResponse> {
239+
let host = parse_host_header(headers)?;
240+
if !host_is_allowed(&host, &config.allowed_hosts) {
241+
return Err(forbidden_response("Forbidden: Host header is not allowed"));
242+
}
243+
244+
Ok(())
245+
}
246+
133247
/// # Streamable HTTP server
134248
///
135249
/// An HTTP service that implements the
@@ -279,6 +393,9 @@ where
279393
B: Body + Send + 'static,
280394
B::Error: Display,
281395
{
396+
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
397+
return response;
398+
}
282399
let method = request.method().clone();
283400
let allowed_methods = match self.config.stateful_mode {
284401
true => "GET, POST, DELETE",

crates/rmcp/tests/test_custom_headers.rs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
761761
.method(Method::POST)
762762
.header("Accept", "application/json, text/event-stream")
763763
.header(CONTENT_TYPE, "application/json")
764+
.header("Host", "localhost:8080")
764765
.body(Full::new(Bytes::from(init_body.to_string())))
765766
.unwrap();
766767

@@ -785,6 +786,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
785786
.method(Method::POST)
786787
.header("Accept", "application/json, text/event-stream")
787788
.header(CONTENT_TYPE, "application/json")
789+
.header("Host", "localhost:8080")
788790
.header("mcp-session-id", &session_id)
789791
.header("mcp-protocol-version", "2025-03-26")
790792
.body(Full::new(Bytes::from(initialized_body.to_string())))
@@ -802,6 +804,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
802804
.method(Method::POST)
803805
.header("Accept", "application/json, text/event-stream")
804806
.header(CONTENT_TYPE, "application/json")
807+
.header("Host", "localhost:8080")
805808
.header("mcp-session-id", &session_id)
806809
.header("mcp-protocol-version", "2025-03-26")
807810
.body(Full::new(Bytes::from(valid_body.to_string())))
@@ -823,6 +826,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
823826
.method(Method::POST)
824827
.header("Accept", "application/json, text/event-stream")
825828
.header(CONTENT_TYPE, "application/json")
829+
.header("Host", "localhost:8080")
826830
.header("mcp-session-id", &session_id)
827831
.header("mcp-protocol-version", "9999-01-01")
828832
.body(Full::new(Bytes::from(invalid_body.to_string())))
@@ -844,6 +848,7 @@ async fn test_server_rejects_unsupported_protocol_version() {
844848
.method(Method::POST)
845849
.header("Accept", "application/json, text/event-stream")
846850
.header(CONTENT_TYPE, "application/json")
851+
.header("Host", "localhost:8080")
847852
.header("mcp-session-id", &session_id)
848853
.body(Full::new(Bytes::from(no_version_body.to_string())))
849854
.unwrap();
@@ -870,3 +875,156 @@ fn test_protocol_version_utilities() {
870875
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26));
871876
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18));
872877
}
878+
879+
/// Integration test: Verify server validates only the Host header for DNS rebinding protection
880+
#[tokio::test]
881+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
882+
async fn test_server_validates_host_header_for_dns_rebinding_protection() {
883+
use std::sync::Arc;
884+
885+
use bytes::Bytes;
886+
use http::{Method, Request, header::CONTENT_TYPE};
887+
use http_body_util::Full;
888+
use rmcp::{
889+
handler::server::ServerHandler,
890+
model::{ServerCapabilities, ServerInfo},
891+
transport::streamable_http_server::{
892+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
893+
},
894+
};
895+
use serde_json::json;
896+
897+
#[derive(Clone)]
898+
struct TestHandler;
899+
900+
impl ServerHandler for TestHandler {
901+
fn get_info(&self) -> ServerInfo {
902+
ServerInfo::new(ServerCapabilities::builder().build())
903+
}
904+
}
905+
906+
let service = StreamableHttpService::new(
907+
|| Ok(TestHandler),
908+
Arc::new(LocalSessionManager::default()),
909+
StreamableHttpServerConfig::default(),
910+
);
911+
912+
let init_body = json!({
913+
"jsonrpc": "2.0",
914+
"id": 1,
915+
"method": "initialize",
916+
"params": {
917+
"protocolVersion": "2025-03-26",
918+
"capabilities": {},
919+
"clientInfo": {
920+
"name": "test-client",
921+
"version": "1.0.0"
922+
}
923+
}
924+
});
925+
926+
let allowed_request = Request::builder()
927+
.method(Method::POST)
928+
.header("Accept", "application/json, text/event-stream")
929+
.header(CONTENT_TYPE, "application/json")
930+
.header("Host", "localhost:8080")
931+
.header("Origin", "http://localhost:8080")
932+
.body(Full::new(Bytes::from(init_body.to_string())))
933+
.unwrap();
934+
935+
let response = service.handle(allowed_request).await;
936+
assert_eq!(response.status(), http::StatusCode::OK);
937+
938+
let bad_host_request = Request::builder()
939+
.method(Method::POST)
940+
.header("Accept", "application/json, text/event-stream")
941+
.header(CONTENT_TYPE, "application/json")
942+
.header("Host", "attacker.example")
943+
.body(Full::new(Bytes::from(init_body.to_string())))
944+
.unwrap();
945+
946+
let response = service.handle(bad_host_request).await;
947+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
948+
949+
let ignored_origin_request = Request::builder()
950+
.method(Method::POST)
951+
.header("Accept", "application/json, text/event-stream")
952+
.header(CONTENT_TYPE, "application/json")
953+
.header("Host", "localhost:8080")
954+
.header("Origin", "http://attacker.example")
955+
.body(Full::new(Bytes::from(init_body.to_string())))
956+
.unwrap();
957+
958+
let response = service.handle(ignored_origin_request).await;
959+
assert_eq!(response.status(), http::StatusCode::OK);
960+
}
961+
962+
/// Integration test: Verify server can enforce an allowed Host port when configured
963+
#[tokio::test]
964+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
965+
async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
966+
use std::sync::Arc;
967+
968+
use bytes::Bytes;
969+
use http::{Method, Request, header::CONTENT_TYPE};
970+
use http_body_util::Full;
971+
use rmcp::{
972+
handler::server::ServerHandler,
973+
model::{ServerCapabilities, ServerInfo},
974+
transport::streamable_http_server::{
975+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
976+
},
977+
};
978+
use serde_json::json;
979+
980+
#[derive(Clone)]
981+
struct TestHandler;
982+
983+
impl ServerHandler for TestHandler {
984+
fn get_info(&self) -> ServerInfo {
985+
ServerInfo::new(ServerCapabilities::builder().build())
986+
}
987+
}
988+
989+
let service = StreamableHttpService::new(
990+
|| Ok(TestHandler),
991+
Arc::new(LocalSessionManager::default()),
992+
StreamableHttpServerConfig::default().with_allowed_hosts(["localhost:8080"]),
993+
);
994+
995+
let init_body = json!({
996+
"jsonrpc": "2.0",
997+
"id": 1,
998+
"method": "initialize",
999+
"params": {
1000+
"protocolVersion": "2025-03-26",
1001+
"capabilities": {},
1002+
"clientInfo": {
1003+
"name": "test-client",
1004+
"version": "1.0.0"
1005+
}
1006+
}
1007+
});
1008+
1009+
let allowed_request = Request::builder()
1010+
.method(Method::POST)
1011+
.header("Accept", "application/json, text/event-stream")
1012+
.header(CONTENT_TYPE, "application/json")
1013+
.header("Host", "localhost:8080")
1014+
.body(Full::new(Bytes::from(init_body.to_string())))
1015+
.unwrap();
1016+
1017+
let response = service.handle(allowed_request).await;
1018+
assert_eq!(response.status(), http::StatusCode::OK);
1019+
1020+
let wrong_port_request = Request::builder()
1021+
.method(Method::POST)
1022+
.header("Accept", "application/json, text/event-stream")
1023+
.header(CONTENT_TYPE, "application/json")
1024+
.header("Host", "localhost:3000")
1025+
.body(Full::new(Bytes::from(init_body.to_string())))
1026+
.unwrap();
1027+
1028+
let response = service.handle(wrong_port_request).await;
1029+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1030+
}

0 commit comments

Comments
 (0)