Skip to content

Commit b881409

Browse files
committed
fix: include http_request_id in request-wise priming event IDs
1 parent 65d2b29 commit b881409

3 files changed

Lines changed: 168 additions & 20 deletions

File tree

crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{
44
time::Duration,
55
};
66

7-
use futures::Stream;
7+
use futures::{Stream, StreamExt};
88
use thiserror::Error;
99
use tokio::sync::{
1010
mpsc::{Receiver, Sender},
@@ -86,10 +86,20 @@ impl SessionManager for LocalSessionManager {
8686
.get(id)
8787
.ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
8888
let receiver = handle.establish_request_wise_channel().await?;
89-
handle
90-
.push_message(message, receiver.http_request_id)
91-
.await?;
92-
Ok(ReceiverStream::new(receiver.inner))
89+
let http_request_id = receiver.http_request_id;
90+
handle.push_message(message, http_request_id).await?;
91+
92+
let priming_events: Vec<ServerSseMessage> = match self.session_config.sse_retry {
93+
Some(retry) => {
94+
let event_id = match http_request_id {
95+
Some(id) => format!("0/{id}"),
96+
None => "0".into(),
97+
};
98+
vec![ServerSseMessage::priming(event_id, retry)]
99+
}
100+
None => vec![],
101+
};
102+
Ok(futures::stream::iter(priming_events).chain(ReceiverStream::new(receiver.inner)))
93103
}
94104

95105
async fn create_standalone_stream(
@@ -188,23 +198,29 @@ struct CachedTx {
188198
cache: VecDeque<ServerSseMessage>,
189199
http_request_id: Option<HttpRequestId>,
190200
capacity: usize,
201+
starting_index: usize,
191202
}
192203

193204
impl CachedTx {
194-
fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
205+
fn new(
206+
tx: Sender<ServerSseMessage>,
207+
http_request_id: Option<HttpRequestId>,
208+
starting_index: usize,
209+
) -> Self {
195210
Self {
196211
cache: VecDeque::with_capacity(tx.capacity()),
197212
capacity: tx.capacity(),
198213
tx,
199214
http_request_id,
215+
starting_index,
200216
}
201217
}
202218
fn new_common(tx: Sender<ServerSseMessage>) -> Self {
203-
Self::new(tx, None)
219+
Self::new(tx, None, 0)
204220
}
205221

206222
fn next_event_id(&self) -> EventId {
207-
let index = self.cache.back().map_or(0, |m| {
223+
let index = self.cache.back().map_or(self.starting_index, |m| {
208224
m.event_id
209225
.as_deref()
210226
.unwrap_or_default()
@@ -405,11 +421,16 @@ impl LocalSessionWorker {
405421
) -> Result<StreamableHttpMessageReceiver, SessionError> {
406422
let http_request_id = self.next_http_request_id();
407423
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
424+
let starting_index = if self.session_config.sse_retry.is_some() {
425+
1
426+
} else {
427+
0
428+
};
408429
self.tx_router.insert(
409430
http_request_id,
410431
HttpRequestWise {
411432
resources: Default::default(),
412-
tx: CachedTx::new(tx, Some(http_request_id)),
433+
tx: CachedTx::new(tx, Some(http_request_id), starting_index),
413434
},
414435
);
415436
tracing::debug!(http_request_id, "establish new request wise channel");
@@ -1072,18 +1093,25 @@ pub struct SessionConfig {
10721093
/// Defaults to 5 minutes. Set to `None` to disable (not recommended
10731094
/// for long-running servers behind proxies).
10741095
pub keep_alive: Option<Duration>,
1096+
/// SSE retry interval for priming events on request-wise streams.
1097+
/// When set, the session layer prepends a priming event with the correct
1098+
/// stream-identifying event ID to each request-wise SSE stream.
1099+
/// Default is 3 seconds, matching [`StreamableHttpServerConfig::default()`].
1100+
pub sse_retry: Option<Duration>,
10751101
}
10761102

10771103
impl SessionConfig {
10781104
pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
10791105
pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
1106+
pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
10801107
}
10811108

10821109
impl Default for SessionConfig {
10831110
fn default() -> Self {
10841111
Self {
10851112
channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
10861113
keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
1114+
sse_retry: Some(Self::DEFAULT_SSE_RETRY),
10871115
}
10881116
}
10891117
}

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -598,20 +598,14 @@ where
598598

599599
match message {
600600
ClientJsonRpcMessage::Request(_) => {
601+
// Priming for request-wise streams is handled by the
602+
// session layer (SessionManager::create_stream) which
603+
// has access to the http_request_id for correct event IDs.
601604
let stream = self
602605
.session_manager
603606
.create_stream(&session_id, message)
604607
.await
605608
.map_err(internal_error_response("get session"))?;
606-
// Prepend priming event if sse_retry configured
607-
let stream = if let Some(retry) = self.config.sse_retry {
608-
let priming = ServerSseMessage::priming("0", retry);
609-
futures::stream::once(async move { priming })
610-
.chain(stream)
611-
.left_stream()
612-
} else {
613-
stream.right_stream()
614-
};
615609
Ok(sse_stream_response(
616610
stream,
617611
self.config.sse_keep_alive,

crates/rmcp/tests/test_streamable_http_priming.rs

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
use std::time::Duration;
33

44
use rmcp::transport::streamable_http_server::{
5-
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
5+
StreamableHttpServerConfig, StreamableHttpService,
6+
session::{SessionId, local::LocalSessionManager},
67
};
78
use tokio_util::sync::CancellationToken;
89

@@ -54,7 +55,7 @@ async fn test_priming_on_stream_start() -> anyhow::Result<()> {
5455
let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect();
5556
assert!(events.len() >= 2);
5657

57-
// Verify priming event (first event)
58+
// Verify priming event (first event) — initialize uses "0" (no http_request_id)
5859
let priming_event = events[0];
5960
assert!(priming_event.contains("id: 0"));
6061
assert!(priming_event.contains("retry: 3000"));
@@ -71,6 +72,131 @@ async fn test_priming_on_stream_start() -> anyhow::Result<()> {
7172
Ok(())
7273
}
7374

75+
#[tokio::test]
76+
async fn test_request_wise_priming_includes_http_request_id() -> anyhow::Result<()> {
77+
let ct = CancellationToken::new();
78+
79+
let service: StreamableHttpService<Calculator, LocalSessionManager> =
80+
StreamableHttpService::new(
81+
|| Ok(Calculator::new()),
82+
Default::default(),
83+
StreamableHttpServerConfig::default()
84+
.with_sse_keep_alive(None)
85+
.with_cancellation_token(ct.child_token()),
86+
);
87+
88+
let router = axum::Router::new().nest_service("/mcp", service);
89+
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
90+
let addr = tcp_listener.local_addr()?;
91+
92+
let handle = tokio::spawn({
93+
let ct = ct.clone();
94+
async move {
95+
let _ = axum::serve(tcp_listener, router)
96+
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
97+
.await;
98+
}
99+
});
100+
101+
let client = reqwest::Client::new();
102+
103+
// Initialize the session
104+
let response = client
105+
.post(format!("http://{addr}/mcp"))
106+
.header("Content-Type", "application/json")
107+
.header("Accept", "application/json, text/event-stream")
108+
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#)
109+
.send()
110+
.await?;
111+
assert_eq!(response.status(), 200);
112+
let session_id: SessionId = response.headers()["mcp-session-id"].to_str()?.into();
113+
114+
// Send notifications/initialized
115+
let status = client
116+
.post(format!("http://{addr}/mcp"))
117+
.header("Content-Type", "application/json")
118+
.header("Accept", "application/json, text/event-stream")
119+
.header("mcp-session-id", session_id.to_string())
120+
.header("Mcp-Protocol-Version", "2025-06-18")
121+
.body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#)
122+
.send()
123+
.await?
124+
.status();
125+
assert_eq!(status, 202);
126+
127+
// First tool call — should get http_request_id 0
128+
let body = client
129+
.post(format!("http://{addr}/mcp"))
130+
.header("Content-Type", "application/json")
131+
.header("Accept", "application/json, text/event-stream")
132+
.header("mcp-session-id", session_id.to_string())
133+
.header("Mcp-Protocol-Version", "2025-06-18")
134+
.body(r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"sum","arguments":{"a":1,"b":2}}}"#)
135+
.send()
136+
.await?
137+
.text()
138+
.await?;
139+
140+
let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect();
141+
assert!(
142+
events.len() >= 2,
143+
"expected priming + response, got: {body}"
144+
);
145+
146+
// Priming event should encode the http_request_id (0)
147+
let priming = events[0];
148+
assert!(
149+
priming.contains("id: 0/0"),
150+
"first request priming should be 0/0, got: {priming}"
151+
);
152+
assert!(priming.contains("retry: 3000"));
153+
154+
// Response event should use index 1 (since priming occupies index 0)
155+
let response_event = events[1];
156+
assert!(
157+
response_event.contains("id: 1/0"),
158+
"first response event id should be 1/0, got: {response_event}"
159+
);
160+
assert!(response_event.contains(r#""id":2"#));
161+
162+
// Second tool call — should get http_request_id 1
163+
let body = client
164+
.post(format!("http://{addr}/mcp"))
165+
.header("Content-Type", "application/json")
166+
.header("Accept", "application/json, text/event-stream")
167+
.header("mcp-session-id", session_id.to_string())
168+
.header("Mcp-Protocol-Version", "2025-06-18")
169+
.body(r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"sum","arguments":{"a":3,"b":4}}}"#)
170+
.send()
171+
.await?
172+
.text()
173+
.await?;
174+
175+
let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect();
176+
assert!(
177+
events.len() >= 2,
178+
"expected priming + response, got: {body}"
179+
);
180+
181+
let priming = events[0];
182+
assert!(
183+
priming.contains("id: 0/1"),
184+
"second request priming should be 0/1, got: {priming}"
185+
);
186+
187+
let response_event = events[1];
188+
assert!(
189+
response_event.contains("id: 1/1"),
190+
"second response event id should be 1/1, got: {response_event}"
191+
);
192+
assert!(response_event.contains(r#""id":3"#));
193+
194+
ct.cancel();
195+
handle.await?;
196+
197+
Ok(())
198+
}
199+
74200
#[tokio::test]
75201
async fn test_priming_on_stream_close() -> anyhow::Result<()> {
76202
use std::sync::Arc;

0 commit comments

Comments
 (0)