Skip to content

Commit 6a071c1

Browse files
committed
fix: pr review comments
1 parent 5f112f6 commit 6a071c1

File tree

1 file changed

+38
-35
lines changed
  • crates/rmcp/src/transport/streamable_http_server

1 file changed

+38
-35
lines changed

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,9 @@ pub struct StreamableHttpServerConfig {
7878
pub session_store: Option<Arc<dyn SessionStore>>,
7979
}
8080

81-
impl std::fmt::Debug for StreamableHttpServerConfig {
81+
impl std::fmt::Debug for dyn SessionStore {
8282
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83-
f.debug_struct("StreamableHttpServerConfig")
84-
.field("sse_keep_alive", &self.sse_keep_alive)
85-
.field("sse_retry", &self.sse_retry)
86-
.field("stateful_mode", &self.stateful_mode)
87-
.field("json_response", &self.json_response)
88-
.field("cancellation_token", &self.cancellation_token)
89-
.field(
90-
"session_store",
91-
&self.session_store.as_ref().map(|_| "<SessionStore>"),
92-
)
93-
.finish()
83+
f.write_str("<SessionStore>")
9484
}
9585
}
9686

@@ -307,6 +297,35 @@ where
307297
}
308298
}
309299

300+
/// Guard used inside [`StreamableHttpService::try_restore_from_store`].
301+
///
302+
/// Ensures the `pending_restores` map entry is always cleaned up — even when
303+
/// the future is cancelled mid-await.
304+
///
305+
/// `result` defaults to `false` (failure / cancellation). Only the success path
306+
/// needs to set it to `true` before returning.
307+
struct PendingRestoreGuard {
308+
pending_restores:
309+
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
310+
session_id: SessionId,
311+
watch_tx: tokio::sync::watch::Sender<Option<bool>>,
312+
/// The value that will be broadcast to waiting tasks on drop.
313+
result: bool,
314+
}
315+
316+
impl Drop for PendingRestoreGuard {
317+
fn drop(&mut self) {
318+
// `send` is synchronous — unblocks waiters immediately, no lock needed.
319+
let _ = self.watch_tx.send(Some(self.result));
320+
// Remove the map entry asynchronously (requires the async write lock).
321+
let pending_restores = self.pending_restores.clone();
322+
let session_id = self.session_id.clone();
323+
tokio::spawn(async move {
324+
pending_restores.write().await.remove(&session_id);
325+
});
326+
}
327+
}
328+
310329
impl<S, M> StreamableHttpService<S, M>
311330
where
312331
S: crate::Service<RoleServer> + Send + 'static,
@@ -422,26 +441,18 @@ where
422441
pending.insert(session_id.clone(), watch_tx.clone());
423442
}
424443

425-
// Helper: signal waiters with the outcome and remove from the pending map.
426-
let finish = {
427-
let pending_restores = pending_restores.clone();
428-
let session_id = session_id.clone();
429-
move |result: bool| {
430-
let pending_restores = pending_restores.clone();
431-
let session_id = session_id.clone();
432-
tokio::spawn(async move {
433-
if let Some(tx) = pending_restores.write().await.remove(&session_id) {
434-
let _ = tx.send(Some(result));
435-
}
436-
});
437-
}
444+
// Guard: signals waiters and cleans up the map entry on drop
445+
let mut guard = PendingRestoreGuard {
446+
pending_restores: pending_restores.clone(),
447+
session_id: session_id.clone(),
448+
watch_tx: watch_tx.clone(),
449+
result: false,
438450
};
439451

440452
// --- Step 3: load from external store ---
441453
let state = match store.load(session_id.as_ref()).await {
442454
Ok(Some(s)) => s,
443455
Ok(None) => {
444-
finish(false);
445456
return Ok(false);
446457
}
447458
Err(e) => {
@@ -450,7 +461,6 @@ where
450461
error = %e,
451462
"session store load failed during restore"
452463
);
453-
finish(false);
454464
return Err(std::io::Error::other(e));
455465
}
456466
};
@@ -466,17 +476,14 @@ where
466476
Ok(RestoreOutcome::AlreadyPresent) => {
467477
// Invariant violation: pending_restores ensures only one task can call
468478
// restore_session per session ID, so AlreadyPresent is impossible here.
469-
finish(false);
470479
return Err(std::io::Error::other(
471480
"restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
472481
));
473482
}
474483
Ok(RestoreOutcome::NotSupported) => {
475-
finish(false);
476484
return Ok(false);
477485
}
478486
Err(e) => {
479-
finish(false);
480487
return Err(e);
481488
}
482489
};
@@ -485,7 +492,6 @@ where
485492
let service = match self.get_service() {
486493
Ok(s) => s,
487494
Err(e) => {
488-
finish(false);
489495
return Err(e);
490496
}
491497
};
@@ -530,7 +536,6 @@ where
530536
.await
531537
.map_err(|e| std::io::Error::other(e.to_string()))
532538
{
533-
finish(false);
534539
return Err(e);
535540
}
536541

@@ -540,19 +545,17 @@ where
540545
.await
541546
.map_err(|e| std::io::Error::other(e.to_string()))
542547
{
543-
finish(false);
544548
return Err(e);
545549
}
546550

547551
if init_done_rx.await.is_err() {
548-
finish(false);
549552
return Err(std::io::Error::other(
550553
"serve_server initialization failed during restore",
551554
));
552555
}
553556

554557
// Restore complete — wake any waiting concurrent requests.
555-
finish(true);
558+
guard.result = true;
556559

557560
tracing::debug!(
558561
session_id = session_id.as_ref(),

0 commit comments

Comments
 (0)