Skip to content

Commit 30ac65c

Browse files
committed
feat(router): auto-send tools/list_changed on disable/enable
1 parent 16f8b8c commit 30ac65c

4 files changed

Lines changed: 450 additions & 25 deletions

File tree

crates/rmcp/src/handler/server/router.rs

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tool::{IntoToolRoute, ToolRoute};
66
use super::ServerHandler;
77
use crate::{
88
RoleServer, Service,
9-
model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
9+
model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
1010
service::NotificationContext,
1111
};
1212

@@ -18,17 +18,22 @@ pub struct Router<S> {
1818
pub tool_router: tool::ToolRouter<S>,
1919
pub prompt_router: prompt::PromptRouter<S>,
2020
pub service: Arc<S>,
21+
peer_slot: Arc<std::sync::OnceLock<crate::service::Peer<RoleServer>>>,
2122
}
2223

2324
impl<S> Router<S>
2425
where
2526
S: ServerHandler,
2627
{
2728
pub fn new(service: S) -> Self {
29+
let (notifier, peer_slot) = tool::ToolRouter::<S>::deferred_peer_notifier();
30+
let mut tool_router = tool::ToolRouter::new();
31+
tool_router.set_notifier(notifier);
2832
Self {
29-
tool_router: tool::ToolRouter::new(),
33+
tool_router,
3034
prompt_router: prompt::PromptRouter::new(),
3135
service: Arc::new(service),
36+
peer_slot,
3237
}
3338
}
3439

@@ -72,6 +77,12 @@ where
7277
notification: <RoleServer as crate::service::ServiceRole>::PeerNot,
7378
context: NotificationContext<RoleServer>,
7479
) -> Result<(), crate::ErrorData> {
80+
if matches!(
81+
&notification,
82+
ClientNotification::InitializedNotification(_)
83+
) {
84+
let _ = self.peer_slot.set(context.peer.clone());
85+
}
7586
self.service
7687
.handle_notification(notification, context)
7788
.await
@@ -137,6 +148,81 @@ where
137148
}
138149

139150
fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
140-
ServerHandler::get_info(&self.service)
151+
let mut info = ServerHandler::get_info(&self.service);
152+
info.capabilities
153+
.tools
154+
.get_or_insert_with(Default::default)
155+
.list_changed = Some(true);
156+
info
157+
}
158+
}
159+
160+
#[cfg(test)]
161+
mod tests {
162+
use std::sync::Arc;
163+
164+
use super::*;
165+
use crate::{
166+
model::{CallToolResult, ClientNotification, ServerNotification, Tool},
167+
service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider},
168+
};
169+
170+
struct DummyHandler;
171+
impl ServerHandler for DummyHandler {}
172+
173+
async fn recv_notification(
174+
rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>,
175+
) -> ServerNotification {
176+
let msg = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
177+
.await
178+
.expect("timed out")
179+
.expect("channel closed");
180+
match msg {
181+
PeerSinkMessage::Notification {
182+
notification,
183+
responder,
184+
} => {
185+
let _ = responder.send(Ok(()));
186+
notification
187+
}
188+
other => panic!("expected notification, got {other:?}"),
189+
}
190+
}
191+
192+
#[tokio::test]
193+
async fn test_router_deferred_notifier_e2e() {
194+
let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn(
195+
Tool::new("my_tool", "test", Arc::new(Default::default())),
196+
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
197+
));
198+
199+
let id_provider: Arc<dyn RequestIdProvider> =
200+
Arc::new(AtomicU32RequestIdProvider::default());
201+
let (peer, mut rx) = Peer::<RoleServer>::new(id_provider, None);
202+
203+
let context = crate::service::NotificationContext {
204+
peer: peer.clone(),
205+
meta: Default::default(),
206+
extensions: Default::default(),
207+
};
208+
router
209+
.handle_notification(
210+
ClientNotification::InitializedNotification(Default::default()),
211+
context,
212+
)
213+
.await
214+
.unwrap();
215+
216+
router.tool_router.disable_route("my_tool");
217+
assert!(matches!(
218+
recv_notification(&mut rx).await,
219+
ServerNotification::ToolListChangedNotification(_)
220+
));
221+
222+
router.tool_router.enable_route("my_tool");
223+
assert!(matches!(
224+
recv_notification(&mut rx).await,
225+
ServerNotification::ToolListChangedNotification(_)
226+
));
141227
}
142228
}

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ where
298298
self
299299
}
300300
}
301-
#[derive(Debug)]
302301
#[non_exhaustive]
303302
pub struct ToolRouter<S> {
304303
#[allow(clippy::type_complexity)]
@@ -307,6 +306,22 @@ pub struct ToolRouter<S> {
307306
pub transparent_when_not_found: bool,
308307

309308
disabled: std::collections::HashSet<Cow<'static, str>>,
309+
310+
notifier: Option<Arc<dyn Fn() + Send + Sync>>,
311+
}
312+
313+
impl<S> std::fmt::Debug for ToolRouter<S> {
314+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315+
f.debug_struct("ToolRouter")
316+
.field("map", &self.map)
317+
.field(
318+
"transparent_when_not_found",
319+
&self.transparent_when_not_found,
320+
)
321+
.field("disabled", &self.disabled)
322+
.field("notifier", &self.notifier.as_ref().map(|_| "..."))
323+
.finish()
324+
}
310325
}
311326

312327
impl<S> Default for ToolRouter<S> {
@@ -315,15 +330,18 @@ impl<S> Default for ToolRouter<S> {
315330
map: std::collections::HashMap::new(),
316331
transparent_when_not_found: false,
317332
disabled: std::collections::HashSet::new(),
333+
notifier: None,
318334
}
319335
}
320336
}
337+
321338
impl<S> Clone for ToolRouter<S> {
322339
fn clone(&self) -> Self {
323340
Self {
324341
map: self.map.clone(),
325342
transparent_when_not_found: self.transparent_when_not_found,
326343
disabled: self.disabled.clone(),
344+
notifier: self.notifier.clone(),
327345
}
328346
}
329347
}
@@ -346,11 +364,7 @@ where
346364
S: MaybeSend + 'static,
347365
{
348366
pub fn new() -> Self {
349-
Self {
350-
map: std::collections::HashMap::new(),
351-
transparent_when_not_found: false,
352-
disabled: std::collections::HashSet::new(),
353-
}
367+
Self::default()
354368
}
355369
pub fn with_route<R, A>(mut self, route: R) -> Self
356370
where
@@ -426,29 +440,30 @@ where
426440
self.map.contains_key(name) && !self.disabled.contains(name)
427441
}
428442

429-
/// Disable a tool by name so it is hidden from `list_all`, `get`, and
430-
/// rejected by `call`. The tool remains in the router and can be
431-
/// re-enabled later with [`enable_route`](Self::enable_route).
443+
/// Disable a tool by name. Hidden from `list_all`, `get`, rejected by
444+
/// `call`. Re-enable with [`enable_route`](Self::enable_route).
432445
///
433446
/// Returns `true` if the name was newly added to the disabled set.
434447
/// The name is recorded even if no matching route exists yet, so routes
435-
/// added later (via [`add_route`](Self::add_route) or
436-
/// [`merge`](Self::merge)) will inherit the disabled state.
437-
///
438-
/// Callers should send `Peer::notify_tool_list_changed` when the
439-
/// visible tool list changes. Accepts `&'static str` or `String`;
440-
/// for a non-static `&str`, call `.to_owned()` first.
448+
/// added later will inherit the disabled state.
441449
pub fn disable_route(&mut self, name: impl Into<Cow<'static, str>>) -> bool {
442-
self.disabled.insert(name.into())
450+
let name = name.into();
451+
let was_visible = self.map.contains_key(&name) && !self.disabled.contains(&name);
452+
let newly_disabled = self.disabled.insert(name.clone());
453+
if was_visible && newly_disabled {
454+
self.notify_if_visible(&name);
455+
}
456+
newly_disabled
443457
}
444458

445459
/// Re-enable a previously disabled tool. Returns `true` if the name
446-
/// was present in the disabled set and was removed.
447-
///
448-
/// Callers should send `Peer::notify_tool_list_changed` when the
449-
/// visible tool list changes.
460+
/// was in the disabled set.
450461
pub fn enable_route(&mut self, name: &str) -> bool {
451-
self.disabled.remove(name)
462+
let removed = self.disabled.remove(name);
463+
if removed {
464+
self.notify_if_visible(name);
465+
}
466+
removed
452467
}
453468

454469
/// Returns `true` if the tool exists in the router **and** is currently
@@ -468,6 +483,58 @@ where
468483
self
469484
}
470485

486+
/// Install a callback invoked when the visible tool list changes.
487+
pub fn set_notifier(&mut self, f: impl Fn() + Send + Sync + 'static) {
488+
self.notifier = Some(Arc::new(f));
489+
}
490+
491+
pub fn clear_notifier(&mut self) {
492+
self.notifier = None;
493+
}
494+
495+
/// Install a notifier that sends `notifications/tools/list_changed`
496+
/// via the given peer.
497+
pub fn bind_peer_notifier(&mut self, peer: &crate::service::Peer<crate::RoleServer>) {
498+
let peer = peer.clone();
499+
self.set_notifier(move || {
500+
let peer = peer.clone();
501+
tokio::spawn(async move {
502+
if let Err(e) = peer.notify_tool_list_changed().await {
503+
tracing::warn!("failed to send tools/list_changed notification: {e}");
504+
}
505+
});
506+
});
507+
}
508+
509+
/// Deferred notifier: no-op until the peer slot is filled.
510+
pub(crate) fn deferred_peer_notifier() -> (
511+
impl Fn() + Send + Sync + 'static,
512+
Arc<std::sync::OnceLock<crate::service::Peer<crate::RoleServer>>>,
513+
) {
514+
let peer_slot =
515+
Arc::new(std::sync::OnceLock::<crate::service::Peer<crate::RoleServer>>::new());
516+
let slot_clone = peer_slot.clone();
517+
let notifier = move || {
518+
if let Some(peer) = slot_clone.get() {
519+
let peer = peer.clone();
520+
tokio::spawn(async move {
521+
if let Err(e) = peer.notify_tool_list_changed().await {
522+
tracing::warn!("failed to send tools/list_changed notification: {e}");
523+
}
524+
});
525+
}
526+
};
527+
(notifier, peer_slot)
528+
}
529+
530+
fn notify_if_visible(&self, name: &str) {
531+
if self.map.contains_key(name) {
532+
if let Some(notifier) = &self.notifier {
533+
(notifier)();
534+
}
535+
}
536+
}
537+
471538
pub async fn call(
472539
&self,
473540
context: ToolCallContext<'_, S>,

0 commit comments

Comments
 (0)