Skip to content

Commit 9ec2cec

Browse files
committed
fix(http): fall back to :authority for HTTP/2
1 parent 4cf7873 commit 9ec2cec

2 files changed

Lines changed: 122 additions & 22 deletions

File tree

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -328,34 +328,43 @@ fn bad_request_response(message: &str) -> BoxResponse {
328328
.expect("failed to build bad request response")
329329
}
330330

331-
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
332-
let Some(host) = headers.get(http::header::HOST) else {
333-
tracing::warn!("rejected request with missing Host header");
334-
return Err(bad_request_response("Bad Request: missing Host header"));
335-
};
336-
337-
let host_str = host
338-
.to_str()
339-
.inspect_err(|_| {
340-
tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
341-
})
342-
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
343-
let authority = http::uri::Authority::try_from(host_str)
344-
.inspect_err(|_| {
345-
tracing::warn!(
346-
host = host_str,
347-
"rejected request with malformed Host header"
348-
);
349-
})
350-
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
331+
fn parse_host_header(
332+
uri: &http::Uri,
333+
headers: &HeaderMap,
334+
) -> Result<NormalizedAuthority, BoxResponse> {
335+
if let Some(host) = headers.get(http::header::HOST) {
336+
let host_str = host
337+
.to_str()
338+
.inspect_err(|_| {
339+
tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
340+
})
341+
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
342+
let authority = http::uri::Authority::try_from(host_str)
343+
.inspect_err(|_| {
344+
tracing::warn!(
345+
host = host_str,
346+
"rejected request with malformed Host header"
347+
);
348+
})
349+
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
350+
return Ok(normalize_authority(authority.host(), authority.port_u16()));
351+
}
352+
// HTTP/2 carries the host in the `:authority` pseudo-header; middleware
353+
// such as `axum::Router::nest` can drop the `Host` header that hyper
354+
// synthesizes from it. Fall back to the URI authority directly.
355+
let authority = uri.authority().ok_or_else(|| {
356+
tracing::warn!("rejected request with missing Host header and no :authority");
357+
bad_request_response("Bad Request: missing Host header")
358+
})?;
351359
Ok(normalize_authority(authority.host(), authority.port_u16()))
352360
}
353361

354362
fn validate_dns_rebinding_headers(
363+
uri: &http::Uri,
355364
headers: &HeaderMap,
356365
config: &StreamableHttpServerConfig,
357366
) -> Result<(), BoxResponse> {
358-
let host = parse_host_header(headers)?;
367+
let host = parse_host_header(uri, headers)?;
359368
if !host_is_allowed(&host, &config.allowed_hosts) {
360369
tracing::warn!(
361370
host = ?host,
@@ -806,7 +815,9 @@ where
806815
B: Body + Send + 'static,
807816
B::Error: Display,
808817
{
809-
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
818+
if let Err(response) =
819+
validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config)
820+
{
810821
return response;
811822
}
812823
let method = request.method().clone();

crates/rmcp/tests/test_custom_headers.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,95 @@ async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
10311031
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
10321032
}
10331033

1034+
/// Integration test: Verify the validator falls back to the URI authority when
1035+
/// the Host header is absent (HTTP/2 :authority pseudo-header scenario).
1036+
#[tokio::test]
1037+
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
1038+
async fn test_server_falls_back_to_uri_authority_when_host_header_missing() {
1039+
use std::sync::Arc;
1040+
1041+
use bytes::Bytes;
1042+
use http::{Method, Request, header::CONTENT_TYPE};
1043+
use http_body_util::Full;
1044+
use rmcp::{
1045+
handler::server::ServerHandler,
1046+
model::{ServerCapabilities, ServerInfo},
1047+
transport::streamable_http_server::{
1048+
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
1049+
},
1050+
};
1051+
use serde_json::json;
1052+
1053+
#[derive(Clone)]
1054+
struct TestHandler;
1055+
1056+
impl ServerHandler for TestHandler {
1057+
fn get_info(&self) -> ServerInfo {
1058+
ServerInfo::new(ServerCapabilities::builder().build())
1059+
}
1060+
}
1061+
1062+
let service = StreamableHttpService::new(
1063+
|| Ok(TestHandler),
1064+
Arc::new(LocalSessionManager::default()),
1065+
StreamableHttpServerConfig::default(),
1066+
);
1067+
1068+
let init_body = json!({
1069+
"jsonrpc": "2.0",
1070+
"id": 1,
1071+
"method": "initialize",
1072+
"params": {
1073+
"protocolVersion": "2025-03-26",
1074+
"capabilities": {},
1075+
"clientInfo": {
1076+
"name": "test-client",
1077+
"version": "1.0.0"
1078+
}
1079+
}
1080+
});
1081+
1082+
// Allowed authority via URI only — no Host header.
1083+
let allowed_request = Request::builder()
1084+
.method(Method::POST)
1085+
.uri("http://localhost:8080/")
1086+
.header("Accept", "application/json, text/event-stream")
1087+
.header(CONTENT_TYPE, "application/json")
1088+
.body(Full::new(Bytes::from(init_body.to_string())))
1089+
.unwrap();
1090+
assert!(allowed_request.headers().get("Host").is_none());
1091+
1092+
let response = service.handle(allowed_request).await;
1093+
assert_eq!(response.status(), http::StatusCode::OK);
1094+
1095+
// Disallowed authority via URI only — no Host header.
1096+
let bad_request = Request::builder()
1097+
.method(Method::POST)
1098+
.uri("http://attacker.example/")
1099+
.header("Accept", "application/json, text/event-stream")
1100+
.header(CONTENT_TYPE, "application/json")
1101+
.body(Full::new(Bytes::from(init_body.to_string())))
1102+
.unwrap();
1103+
assert!(bad_request.headers().get("Host").is_none());
1104+
1105+
let response = service.handle(bad_request).await;
1106+
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
1107+
1108+
// Neither Host header nor URI authority — still a 400.
1109+
let missing_request = Request::builder()
1110+
.method(Method::POST)
1111+
.uri("/")
1112+
.header("Accept", "application/json, text/event-stream")
1113+
.header(CONTENT_TYPE, "application/json")
1114+
.body(Full::new(Bytes::from(init_body.to_string())))
1115+
.unwrap();
1116+
assert!(missing_request.headers().get("Host").is_none());
1117+
assert!(missing_request.uri().authority().is_none());
1118+
1119+
let response = service.handle(missing_request).await;
1120+
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
1121+
}
1122+
10341123
#[cfg(all(feature = "transport-streamable-http-server", feature = "server"))]
10351124
mod origin_validation {
10361125
use std::sync::Arc;

0 commit comments

Comments
 (0)