Skip to content

Commit efc224c

Browse files
committed
fix: pr review comments
1 parent b241cd3 commit efc224c

File tree

1 file changed

+39
-36
lines changed
  • crates/rmcp/src/transport/streamable_http_server

1 file changed

+39
-36
lines changed

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::{
3434
},
3535
};
3636

37-
#[derive(Clone)]
37+
#[derive(Debug, Clone)]
3838
pub struct StreamableHttpServerConfig {
3939
/// The ping message duration for SSE connections.
4040
pub sse_keep_alive: Option<Duration>,
@@ -77,19 +77,9 @@ pub struct StreamableHttpServerConfig {
7777
pub session_store: Option<Arc<dyn SessionStore>>,
7878
}
7979

80-
impl std::fmt::Debug for StreamableHttpServerConfig {
80+
impl std::fmt::Debug for dyn SessionStore {
8181
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82-
f.debug_struct("StreamableHttpServerConfig")
83-
.field("sse_keep_alive", &self.sse_keep_alive)
84-
.field("sse_retry", &self.sse_retry)
85-
.field("stateful_mode", &self.stateful_mode)
86-
.field("json_response", &self.json_response)
87-
.field("cancellation_token", &self.cancellation_token)
88-
.field(
89-
"session_store",
90-
&self.session_store.as_ref().map(|_| "<SessionStore>"),
91-
)
92-
.finish()
82+
f.write_str("<SessionStore>")
9383
}
9484
}
9585

@@ -279,6 +269,35 @@ where
279269
}
280270
}
281271

272+
/// Guard used inside [`StreamableHttpService::try_restore_from_store`].
273+
///
274+
/// Ensures the `pending_restores` map entry is always cleaned up — even when
275+
/// the future is cancelled mid-await.
276+
///
277+
/// `result` defaults to `false` (failure / cancellation). Only the success path
278+
/// needs to set it to `true` before returning.
279+
struct PendingRestoreGuard {
280+
pending_restores:
281+
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
282+
session_id: SessionId,
283+
watch_tx: tokio::sync::watch::Sender<Option<bool>>,
284+
/// The value that will be broadcast to waiting tasks on drop.
285+
result: bool,
286+
}
287+
288+
impl Drop for PendingRestoreGuard {
289+
fn drop(&mut self) {
290+
// `send` is synchronous — unblocks waiters immediately, no lock needed.
291+
let _ = self.watch_tx.send(Some(self.result));
292+
// Remove the map entry asynchronously (requires the async write lock).
293+
let pending_restores = self.pending_restores.clone();
294+
let session_id = self.session_id.clone();
295+
tokio::spawn(async move {
296+
pending_restores.write().await.remove(&session_id);
297+
});
298+
}
299+
}
300+
282301
impl<S, M> StreamableHttpService<S, M>
283302
where
284303
S: crate::Service<RoleServer> + Send + 'static,
@@ -394,26 +413,18 @@ where
394413
pending.insert(session_id.clone(), watch_tx.clone());
395414
}
396415

397-
// Helper: signal waiters with the outcome and remove from the pending map.
398-
let finish = {
399-
let pending_restores = pending_restores.clone();
400-
let session_id = session_id.clone();
401-
move |result: bool| {
402-
let pending_restores = pending_restores.clone();
403-
let session_id = session_id.clone();
404-
tokio::spawn(async move {
405-
if let Some(tx) = pending_restores.write().await.remove(&session_id) {
406-
let _ = tx.send(Some(result));
407-
}
408-
});
409-
}
416+
// Guard: signals waiters and cleans up the map entry on drop
417+
let mut guard = PendingRestoreGuard {
418+
pending_restores: pending_restores.clone(),
419+
session_id: session_id.clone(),
420+
watch_tx: watch_tx.clone(),
421+
result: false,
410422
};
411423

412424
// --- Step 3: load from external store ---
413425
let state = match store.load(session_id.as_ref()).await {
414426
Ok(Some(s)) => s,
415427
Ok(None) => {
416-
finish(false);
417428
return Ok(false);
418429
}
419430
Err(e) => {
@@ -422,7 +433,6 @@ where
422433
error = %e,
423434
"session store load failed during restore"
424435
);
425-
finish(false);
426436
return Err(std::io::Error::other(e));
427437
}
428438
};
@@ -438,17 +448,14 @@ where
438448
Ok(RestoreOutcome::AlreadyPresent) => {
439449
// Invariant violation: pending_restores ensures only one task can call
440450
// restore_session per session ID, so AlreadyPresent is impossible here.
441-
finish(false);
442451
return Err(std::io::Error::other(
443452
"restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
444453
));
445454
}
446455
Ok(RestoreOutcome::NotSupported) => {
447-
finish(false);
448456
return Ok(false);
449457
}
450458
Err(e) => {
451-
finish(false);
452459
return Err(e);
453460
}
454461
};
@@ -457,7 +464,6 @@ where
457464
let service = match self.get_service() {
458465
Ok(s) => s,
459466
Err(e) => {
460-
finish(false);
461467
return Err(e);
462468
}
463469
};
@@ -502,7 +508,6 @@ where
502508
.await
503509
.map_err(|e| std::io::Error::other(e.to_string()))
504510
{
505-
finish(false);
506511
return Err(e);
507512
}
508513

@@ -512,19 +517,17 @@ where
512517
.await
513518
.map_err(|e| std::io::Error::other(e.to_string()))
514519
{
515-
finish(false);
516520
return Err(e);
517521
}
518522

519523
if init_done_rx.await.is_err() {
520-
finish(false);
521524
return Err(std::io::Error::other(
522525
"serve_server initialization failed during restore",
523526
));
524527
}
525528

526529
// Restore complete — wake any waiting concurrent requests.
527-
finish(true);
530+
guard.result = true;
528531

529532
tracing::debug!(
530533
session_id = session_id.as_ref(),

0 commit comments

Comments
 (0)