Skip to content

Commit ee1c63c

Browse files
authored
feat(transport): add Unix domain socket client for streamable HTTP (#749)
* feat(transport): add Unix domain socket client for streamable HTTP MCP hosts in Kubernetes environments with Envoy sidecars need to route HTTP through Unix domain sockets because DNS-based URIs only resolve via the proxy. Adds UnixSocketHttpClient implementing StreamableHttpClient using hyper over tokio::net::UnixStream, gated behind the transport-streamable-http-client-unix-socket feature. Also extracts RESERVED_HEADERS, extract_scope_from_header, and validate_custom_header into common/http_header.rs to share header validation logic between the reqwest and unix socket implementations. * fix(transport): address review feedback for unix socket transport - Document one-connection-per-request behavior on UnixSocketHttpClient - Reject empty socket paths and bare '@' in constructor with assert - Add explicit dep:http to unix-socket feature for self-documenting deps - Document MCP-Protocol-Version exception on RESERVED_HEADERS constant - Fix test catch-all to echo request id instead of hardcoding 1 - Remove leftover sleep(100ms) in test_unix_socket_custom_headers - Add blank line before macro comment in Cargo.toml * fix(transport): fix CI failures for unix socket transport - Use std::io::Error::other() instead of Error::new(ErrorKind::Other) to satisfy clippy::io_other_error on newer nightly - Use #[tokio::test(flavor = "current_thread")] for unix socket tests since axum's serve(UnixListener) requires spawn_local - Gate validate_custom_header behind client-side-sse feature since it references http::HeaderName which isn't available with default features * fix(transport): fix CI failures for unix socket transport axum::serve(UnixListener) uses spawn_local on Linux, which panics outside a LocalSet. Replace with manual hyper HTTP/1.1 server that accepts connections directly from the UnixListener, avoiding the spawn_local requirement entirely. * fix(transport): skip unix socket tests when local feature is enabled The local feature causes ().serve(transport) to use spawn_local, which requires a LocalSet. Gate the integration tests with not(feature = "local") to match every other integration test in the repo.
1 parent a32a9c8 commit ee1c63c

File tree

7 files changed

+998
-85
lines changed

7 files changed

+998
-85
lines changed

crates/rmcp/Cargo.toml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ uuid = { version = "1", features = ["v4"], optional = true }
6262
http-body = { version = "1", optional = true }
6363
http-body-util = { version = "0.1", optional = true }
6464
bytes = { version = "1", optional = true }
65+
66+
# for unix socket transport
67+
hyper = { version = "1", features = ["client", "http1"], optional = true }
68+
hyper-util = { version = "0.1", features = ["tokio"], optional = true }
69+
6570
# macro
6671
rmcp-macros = { workspace = true, optional = true }
6772
[target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies]
@@ -112,6 +117,15 @@ client-side-sse = ["dep:sse-stream", "dep:http"]
112117
# Streamable HTTP client
113118
transport-streamable-http-client = ["client-side-sse", "transport-worker"]
114119
transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"]
120+
transport-streamable-http-client-unix-socket = [
121+
"transport-streamable-http-client",
122+
"dep:hyper",
123+
"dep:hyper-util",
124+
"dep:http-body-util",
125+
"dep:http",
126+
"dep:bytes",
127+
"tokio/net",
128+
]
115129

116130
transport-async-rw = ["tokio/io-util", "tokio-util/codec"]
117131
transport-io = ["transport-async-rw", "tokio/io-std"]
@@ -139,6 +153,9 @@ schemars = ["dep:schemars"]
139153
tokio = { version = "1", features = ["full"] }
140154
schemars = { version = "1.1.0", features = ["chrono04"] }
141155
axum = { version = "0.8", default-features = false, features = ["http1", "tokio"] }
156+
hyper = { version = "1", features = ["server", "http1"] }
157+
hyper-util = { version = "0.1", features = ["tokio"] }
158+
tower-service = "0.3"
142159
url = "2.4"
143160
anyhow = "1.0"
144161
tracing-subscriber = { version = "0.3", features = [
@@ -266,6 +283,15 @@ name = "test_client_credentials"
266283
required-features = ["auth"]
267284
path = "tests/test_client_credentials.rs"
268285

286+
[[test]]
287+
name = "test_unix_socket_transport"
288+
required-features = [
289+
"client",
290+
"server",
291+
"transport-streamable-http-client-unix-socket",
292+
]
293+
path = "tests/test_unix_socket_transport.rs"
294+
269295
[[test]]
270296
name = "test_streamable_http_stale_session"
271297
required-features = [

crates/rmcp/src/transport.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHt
112112

113113
#[cfg(feature = "transport-streamable-http-client")]
114114
pub mod streamable_http_client;
115+
#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))]
116+
pub use common::unix_socket::UnixSocketHttpClient;
115117
#[cfg(feature = "transport-streamable-http-client")]
116118
pub use streamable_http_client::StreamableHttpClientTransport;
117119

crates/rmcp/src/transport/common.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ pub mod client_side_sse;
1414

1515
#[cfg(feature = "auth")]
1616
pub mod auth;
17+
18+
#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))]
19+
pub mod unix_socket;

crates/rmcp/src/transport/common/http_header.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,122 @@ pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
33
pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version";
44
pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
55
pub const JSON_MIME_TYPE: &str = "application/json";
6+
7+
/// Reserved headers that must not be overridden by user-supplied custom headers.
8+
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
9+
/// injects it after initialization.
10+
pub(crate) const RESERVED_HEADERS: &[&str] = &[
11+
"accept",
12+
HEADER_SESSION_ID,
13+
HEADER_MCP_PROTOCOL_VERSION, // allowed through by validate_custom_header; worker injects it post-init
14+
HEADER_LAST_EVENT_ID,
15+
];
16+
17+
/// Checks whether a custom header name is allowed.
18+
/// Returns `Ok(())` if allowed, `Err(name)` if rejected as reserved.
19+
/// `MCP-Protocol-Version` is reserved but allowed through (the worker injects it post-init).
20+
#[cfg(feature = "client-side-sse")]
21+
pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), String> {
22+
if RESERVED_HEADERS
23+
.iter()
24+
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
25+
{
26+
if name
27+
.as_str()
28+
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
29+
{
30+
return Ok(());
31+
}
32+
return Err(name.to_string());
33+
}
34+
Ok(())
35+
}
36+
37+
/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value.
38+
/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
39+
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {
40+
let header_lowercase = header.to_ascii_lowercase();
41+
let scope_key = "scope=";
42+
43+
if let Some(pos) = header_lowercase.find(scope_key) {
44+
let start = pos + scope_key.len();
45+
let value_slice = &header[start..];
46+
47+
if let Some(stripped) = value_slice.strip_prefix('"') {
48+
if let Some(end_quote) = stripped.find('"') {
49+
return Some(stripped[..end_quote].to_string());
50+
}
51+
} else {
52+
let end = value_slice
53+
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
54+
.unwrap_or(value_slice.len());
55+
if end > 0 {
56+
return Some(value_slice[..end].to_string());
57+
}
58+
}
59+
}
60+
61+
None
62+
}
63+
64+
#[cfg(test)]
65+
mod tests {
66+
use super::*;
67+
68+
#[test]
69+
fn extract_scope_quoted() {
70+
let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#;
71+
assert_eq!(
72+
extract_scope_from_header(header),
73+
Some("files:read files:write".to_string())
74+
);
75+
}
76+
77+
#[test]
78+
fn extract_scope_unquoted() {
79+
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
80+
assert_eq!(
81+
extract_scope_from_header(header),
82+
Some("read:data".to_string())
83+
);
84+
}
85+
86+
#[test]
87+
fn extract_scope_missing() {
88+
let header = r#"Bearer error="invalid_token""#;
89+
assert_eq!(extract_scope_from_header(header), None);
90+
}
91+
92+
#[test]
93+
fn extract_scope_empty_header() {
94+
assert_eq!(extract_scope_from_header("Bearer"), None);
95+
}
96+
97+
#[cfg(feature = "client-side-sse")]
98+
#[test]
99+
fn validate_rejects_reserved_accept() {
100+
let name = http::HeaderName::from_static("accept");
101+
assert!(validate_custom_header(&name).is_err());
102+
}
103+
104+
#[cfg(feature = "client-side-sse")]
105+
#[test]
106+
fn validate_rejects_reserved_session_id() {
107+
let name = http::HeaderName::from_static("mcp-session-id");
108+
assert!(validate_custom_header(&name).is_err());
109+
}
110+
111+
#[cfg(feature = "client-side-sse")]
112+
#[test]
113+
fn validate_allows_mcp_protocol_version() {
114+
let name = http::HeaderName::from_static("mcp-protocol-version");
115+
assert!(validate_custom_header(&name).is_ok());
116+
}
117+
118+
#[cfg(feature = "client-side-sse")]
119+
#[test]
120+
fn validate_allows_custom_header() {
121+
let name = http::HeaderName::from_static("x-custom");
122+
assert!(validate_custom_header(&name).is_ok());
123+
}
124+
}

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 5 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use crate::{
99
model::{ClientJsonRpcMessage, JsonRpcMessage, ServerJsonRpcMessage},
1010
transport::{
1111
common::http_header::{
12-
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
13-
HEADER_SESSION_ID, JSON_MIME_TYPE,
12+
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
13+
extract_scope_from_header, validate_custom_header,
1414
},
1515
streamable_http_client::*,
1616
},
@@ -22,38 +22,13 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
2222
}
2323
}
2424

25-
/// Reserved headers that must not be overridden by user-supplied custom headers.
26-
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
27-
/// injects it after initialization.
28-
const RESERVED_HEADERS: &[&str] = &[
29-
"accept",
30-
HEADER_SESSION_ID,
31-
HEADER_MCP_PROTOCOL_VERSION,
32-
HEADER_LAST_EVENT_ID,
33-
];
34-
35-
/// Applies custom headers to a request builder, rejecting reserved headers
36-
/// except `MCP-Protocol-Version` (which the worker injects after init).
25+
/// Applies custom headers to a request builder, rejecting reserved headers.
3726
fn apply_custom_headers(
3827
mut builder: reqwest::RequestBuilder,
3928
custom_headers: HashMap<HeaderName, HeaderValue>,
4029
) -> Result<reqwest::RequestBuilder, StreamableHttpError<reqwest::Error>> {
4130
for (name, value) in custom_headers {
42-
if RESERVED_HEADERS
43-
.iter()
44-
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
45-
{
46-
if name
47-
.as_str()
48-
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
49-
{
50-
builder = builder.header(name, value);
51-
continue;
52-
}
53-
return Err(StreamableHttpError::ReservedHeaderConflict(
54-
name.to_string(),
55-
));
56-
}
31+
validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?;
5732
builder = builder.header(name, value);
5833
}
5934
Ok(builder)
@@ -306,66 +281,11 @@ impl StreamableHttpClientTransport<reqwest::Client> {
306281
}
307282
}
308283

309-
/// extract scope parameter from WWW-Authenticate header
310-
fn extract_scope_from_header(header: &str) -> Option<String> {
311-
let header_lowercase = header.to_ascii_lowercase();
312-
let scope_key = "scope=";
313-
314-
if let Some(pos) = header_lowercase.find(scope_key) {
315-
let start = pos + scope_key.len();
316-
let value_slice = &header[start..];
317-
318-
if let Some(stripped) = value_slice.strip_prefix('"') {
319-
if let Some(end_quote) = stripped.find('"') {
320-
return Some(stripped[..end_quote].to_string());
321-
}
322-
} else {
323-
let end = value_slice
324-
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
325-
.unwrap_or(value_slice.len());
326-
if end > 0 {
327-
return Some(value_slice[..end].to_string());
328-
}
329-
}
330-
}
331-
332-
None
333-
}
334-
335284
#[cfg(test)]
336285
mod tests {
337-
use super::{extract_scope_from_header, parse_json_rpc_error};
286+
use super::parse_json_rpc_error;
338287
use crate::{model::JsonRpcMessage, transport::streamable_http_client::InsufficientScopeError};
339288

340-
#[test]
341-
fn extract_scope_quoted() {
342-
let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#;
343-
assert_eq!(
344-
extract_scope_from_header(header),
345-
Some("files:read files:write".to_string())
346-
);
347-
}
348-
349-
#[test]
350-
fn extract_scope_unquoted() {
351-
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
352-
assert_eq!(
353-
extract_scope_from_header(header),
354-
Some("read:data".to_string())
355-
);
356-
}
357-
358-
#[test]
359-
fn extract_scope_missing() {
360-
let header = r#"Bearer error="invalid_token""#;
361-
assert_eq!(extract_scope_from_header(header), None);
362-
}
363-
364-
#[test]
365-
fn extract_scope_empty_header() {
366-
assert_eq!(extract_scope_from_header("Bearer"), None);
367-
}
368-
369289
#[test]
370290
fn insufficient_scope_error_can_upgrade() {
371291
let with_scope = InsufficientScopeError {

0 commit comments

Comments
 (0)