Skip to content

Commit 9753d61

Browse files
authored
feat(http): add Origin header validation (#823)
1 parent 63583b1 commit 9753d61

2 files changed

Lines changed: 206 additions & 0 deletions

File tree

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ pub struct StreamableHttpServerConfig {
6464
/// or with ports:
6565
/// allowed_hosts = ["example.com", "example.com:8080"]
6666
pub allowed_hosts: Vec<String>,
67+
/// Allowed browser origins for inbound `Origin` validation.
68+
///
69+
/// Defaults to an empty list, which disables Origin validation. When
70+
/// non-empty, requests carrying an `Origin` header must match per RFC 6454
71+
/// `(scheme, host, port)`; missing-`Origin` requests still pass. Entries
72+
/// must include a scheme; `"null"` matches the browser's `Origin: null`.
73+
/// examples:
74+
/// allowed_origins = ["https://app.example.com", "http://localhost:8080"]
75+
pub allowed_origins: Vec<String>,
6776
/// Optional external session store for cross-instance recovery.
6877
///
6978
/// When set, [`SessionState`] (the client's `initialize` parameters) is
@@ -103,6 +112,7 @@ impl Default for StreamableHttpServerConfig {
103112
json_response: false,
104113
cancellation_token: CancellationToken::new(),
105114
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
115+
allowed_origins: vec![],
106116
session_store: None,
107117
}
108118
}
@@ -121,6 +131,18 @@ impl StreamableHttpServerConfig {
121131
self.allowed_hosts.clear();
122132
self
123133
}
134+
pub fn with_allowed_origins(
135+
mut self,
136+
allowed_origins: impl IntoIterator<Item = impl Into<String>>,
137+
) -> Self {
138+
self.allowed_origins = allowed_origins.into_iter().map(Into::into).collect();
139+
self
140+
}
141+
/// Disable Origin validation, reverting to the default ignore-Origin behavior.
142+
pub fn disable_allowed_origins(mut self) -> Self {
143+
self.allowed_origins.clear();
144+
self
145+
}
124146
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {
125147
self.sse_keep_alive = duration;
126148
self
@@ -243,6 +265,59 @@ fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool
243265
})
244266
}
245267

268+
#[derive(Debug, Clone, PartialEq, Eq)]
269+
enum NormalizedOrigin {
270+
Null,
271+
Tuple {
272+
scheme: String,
273+
host: String,
274+
port: Option<u16>,
275+
},
276+
}
277+
278+
fn parse_origin_value(value: &str) -> Option<NormalizedOrigin> {
279+
let value = value.trim();
280+
if value.is_empty() {
281+
return None;
282+
}
283+
if value.eq_ignore_ascii_case("null") {
284+
return Some(NormalizedOrigin::Null);
285+
}
286+
let uri = http::Uri::try_from(value).ok()?;
287+
let scheme = uri.scheme_str()?.to_ascii_lowercase();
288+
let authority = uri.authority()?;
289+
Some(NormalizedOrigin::Tuple {
290+
scheme,
291+
host: normalize_host(authority.host()),
292+
port: authority.port_u16(),
293+
})
294+
}
295+
296+
fn origin_is_allowed(origin: &NormalizedOrigin, allowed_origins: &[String]) -> bool {
297+
if allowed_origins.is_empty() {
298+
return true;
299+
}
300+
allowed_origins
301+
.iter()
302+
.filter_map(|raw| parse_origin_value(raw))
303+
.any(|allowed| match (&allowed, origin) {
304+
(NormalizedOrigin::Null, NormalizedOrigin::Null) => true,
305+
(
306+
NormalizedOrigin::Tuple {
307+
scheme: a_scheme,
308+
host: a_host,
309+
port: a_port,
310+
},
311+
NormalizedOrigin::Tuple {
312+
scheme: o_scheme,
313+
host: o_host,
314+
port: o_port,
315+
},
316+
) => a_scheme == o_scheme && a_host == o_host && (a_port.is_none() || a_port == o_port),
317+
_ => false,
318+
})
319+
}
320+
246321
fn bad_request_response(message: &str) -> BoxResponse {
247322
let body = Full::from(message.to_string()).boxed();
248323

@@ -274,7 +349,30 @@ fn validate_dns_rebinding_headers(
274349
if !host_is_allowed(&host, &config.allowed_hosts) {
275350
return Err(forbidden_response("Forbidden: Host header is not allowed"));
276351
}
352+
validate_origin_header(headers, &config.allowed_origins)?;
353+
Ok(())
354+
}
277355

356+
fn validate_origin_header(
357+
headers: &HeaderMap,
358+
allowed_origins: &[String],
359+
) -> Result<(), BoxResponse> {
360+
if allowed_origins.is_empty() {
361+
return Ok(());
362+
}
363+
let Some(origin_header) = headers.get(http::header::ORIGIN) else {
364+
return Ok(());
365+
};
366+
let origin_str = origin_header
367+
.to_str()
368+
.map_err(|_| bad_request_response("Bad Request: Invalid Origin header encoding"))?;
369+
let origin = parse_origin_value(origin_str)
370+
.ok_or_else(|| bad_request_response("Bad Request: Invalid Origin header"))?;
371+
if !origin_is_allowed(&origin, allowed_origins) {
372+
return Err(forbidden_response(
373+
"Forbidden: Origin header is not allowed",
374+
));
375+
}
278376
Ok(())
279377
}
280378

crates/rmcp/tests/test_custom_headers.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,3 +1030,111 @@ async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
10301030
let response = service.handle(wrong_port_request).await;
10311031
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
10321032
}
1033+
1034+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server"))]
1035+
mod origin_validation {
1036+
use std::sync::Arc;
1037+
1038+
use bytes::Bytes;
1039+
use http::{Method, Request, header::CONTENT_TYPE};
1040+
use http_body_util::Full;
1041+
use rmcp::{
1042+
handler::server::ServerHandler,
1043+
model::{ServerCapabilities, ServerInfo},
1044+
transport::streamable_http_server::{
1045+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
1046+
},
1047+
};
1048+
use serde_json::json;
1049+
1050+
#[derive(Clone)]
1051+
struct TestHandler;
1052+
1053+
impl ServerHandler for TestHandler {
1054+
fn get_info(&self) -> ServerInfo {
1055+
ServerInfo::new(ServerCapabilities::builder().build())
1056+
}
1057+
}
1058+
1059+
fn service_with_allowed_origins(
1060+
origins: &[&str],
1061+
) -> StreamableHttpService<TestHandler, LocalSessionManager> {
1062+
StreamableHttpService::new(
1063+
|| Ok(TestHandler),
1064+
Arc::new(LocalSessionManager::default()),
1065+
StreamableHttpServerConfig::default().with_allowed_origins(origins.iter().copied()),
1066+
)
1067+
}
1068+
1069+
fn init_request(origin: Option<&str>) -> Request<Full<Bytes>> {
1070+
let init_body = json!({
1071+
"jsonrpc": "2.0",
1072+
"id": 1,
1073+
"method": "initialize",
1074+
"params": {
1075+
"protocolVersion": "2025-03-26",
1076+
"capabilities": {},
1077+
"clientInfo": {"name": "test-client", "version": "1.0.0"}
1078+
}
1079+
});
1080+
let mut builder = Request::builder()
1081+
.method(Method::POST)
1082+
.header("Accept", "application/json, text/event-stream")
1083+
.header(CONTENT_TYPE, "application/json")
1084+
.header("Host", "localhost:8080");
1085+
if let Some(origin) = origin {
1086+
builder = builder.header("Origin", origin);
1087+
}
1088+
builder
1089+
.body(Full::new(Bytes::from(init_body.to_string())))
1090+
.unwrap()
1091+
}
1092+
1093+
#[tokio::test]
1094+
async fn allowlisted_origin_is_allowed() {
1095+
let service = service_with_allowed_origins(&["http://localhost:8080"]);
1096+
let response = service
1097+
.handle(init_request(Some("http://localhost:8080")))
1098+
.await;
1099+
assert_eq!(response.status(), http::StatusCode::OK);
1100+
}
1101+
1102+
#[tokio::test]
1103+
async fn non_allowlisted_origin_is_forbidden() {
1104+
let service = service_with_allowed_origins(&["http://localhost:8080"]);
1105+
let response = service
1106+
.handle(init_request(Some("http://attacker.example")))
1107+
.await;
1108+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1109+
}
1110+
1111+
#[tokio::test]
1112+
async fn missing_origin_passes_through() {
1113+
let service = service_with_allowed_origins(&["http://localhost:8080"]);
1114+
let response = service.handle(init_request(None)).await;
1115+
assert_eq!(response.status(), http::StatusCode::OK);
1116+
}
1117+
1118+
#[tokio::test]
1119+
async fn scheme_mismatch_is_forbidden() {
1120+
let service = service_with_allowed_origins(&["http://localhost:8080"]);
1121+
let response = service
1122+
.handle(init_request(Some("https://localhost:8080")))
1123+
.await;
1124+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1125+
}
1126+
1127+
#[tokio::test]
1128+
async fn null_origin_is_allowed_when_allowlisted() {
1129+
let service = service_with_allowed_origins(&["null"]);
1130+
let response = service.handle(init_request(Some("null"))).await;
1131+
assert_eq!(response.status(), http::StatusCode::OK);
1132+
}
1133+
1134+
#[tokio::test]
1135+
async fn null_origin_is_forbidden_when_not_allowlisted() {
1136+
let service = service_with_allowed_origins(&["http://localhost:8080"]);
1137+
let response = service.handle(init_request(Some("null"))).await;
1138+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1139+
}
1140+
}

0 commit comments

Comments
 (0)