Skip to content

Commit 16f8b8c

Browse files
committed
fix(router): simplify disable tool api
1 parent 8750031 commit 16f8b8c

3 files changed

Lines changed: 141 additions & 19 deletions

File tree

crates/rmcp/src/handler/server/router.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,10 @@ where
8383
) -> Result<<RoleServer as crate::service::ServiceRole>::Resp, crate::ErrorData> {
8484
match request {
8585
ClientRequest::CallToolRequest(request) => {
86-
if self.tool_router.has_route(request.params.name.as_ref())
87-
|| self.tool_router.is_disabled(request.params.name.as_ref())
86+
if self
87+
.tool_router
88+
.map
89+
.contains_key(request.params.name.as_ref())
8890
|| !self.tool_router.transparent_when_not_found
8991
{
9092
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,19 @@ where
409409
}
410410
}
411411

412+
/// Remove a tool route from the router.
413+
///
414+
/// The disabled state is **preserved**: if the name was in the disabled
415+
/// set, it stays there so that a future [`add_route`](Self::add_route)
416+
/// or [`merge`](Self::merge) with the same name will inherit the
417+
/// disabled state. To also clear the disabled marker, call
418+
/// [`enable_route`](Self::enable_route) afterwards.
412419
pub fn remove_route(&mut self, name: &str) {
413420
self.map.remove(name);
414-
self.disabled.remove(name);
415421
}
416422

423+
/// Returns `true` if the tool is registered **and** not currently
424+
/// disabled.
417425
pub fn has_route(&self, name: &str) -> bool {
418426
self.map.contains_key(name) && !self.disabled.contains(name)
419427
}
@@ -422,20 +430,30 @@ where
422430
/// rejected by `call`. The tool remains in the router and can be
423431
/// re-enabled later with [`enable_route`](Self::enable_route).
424432
///
433+
/// Returns `true` if the name was newly added to the disabled set.
425434
/// The name is recorded even if no matching route exists yet, so routes
426435
/// added later (via [`add_route`](Self::add_route) or
427436
/// [`merge`](Self::merge)) will inherit the disabled state.
428-
pub fn disable_route(&mut self, name: &str) {
429-
self.disabled.insert(Cow::Owned(name.to_owned()));
437+
///
438+
/// Callers should send `Peer::notify_tool_list_changed` when the
439+
/// visible tool list changes. Accepts `&'static str` or `String`;
440+
/// for a non-static `&str`, call `.to_owned()` first.
441+
pub fn disable_route(&mut self, name: impl Into<Cow<'static, str>>) -> bool {
442+
self.disabled.insert(name.into())
430443
}
431444

432-
/// Re-enable a previously disabled tool.
433-
pub fn enable_route(&mut self, name: &str) {
434-
self.disabled.remove(name);
445+
/// Re-enable a previously disabled tool. Returns `true` if the name
446+
/// was present in the disabled set and was removed.
447+
///
448+
/// Callers should send `Peer::notify_tool_list_changed` when the
449+
/// visible tool list changes.
450+
pub fn enable_route(&mut self, name: &str) -> bool {
451+
self.disabled.remove(name)
435452
}
436453

437-
/// Returns `true` if the tool exists in the router but is currently
438-
/// disabled.
454+
/// Returns `true` if the tool exists in the router **and** is currently
455+
/// disabled. Returns `false` if the tool does not exist or if the name
456+
/// was pre-disabled without a matching route.
439457
pub fn is_disabled(&self, name: &str) -> bool {
440458
self.map.contains_key(name) && self.disabled.contains(name)
441459
}
@@ -511,3 +529,49 @@ where
511529
self.merge(other);
512530
}
513531
}
532+
533+
#[cfg(test)]
534+
mod tests {
535+
use std::sync::Arc;
536+
537+
use super::*;
538+
use crate::{
539+
RoleServer,
540+
model::{CallToolRequestParams, ErrorCode, NumberOrString},
541+
service::{AtomicU32RequestIdProvider, Peer, RequestContext},
542+
};
543+
544+
struct DummyService;
545+
impl crate::handler::server::ServerHandler for DummyService {}
546+
547+
#[tokio::test]
548+
async fn test_call_disabled_tool_returns_error() {
549+
let service = DummyService;
550+
let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn(
551+
crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())),
552+
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
553+
));
554+
router.disable_route("test_tool");
555+
556+
let id_provider: Arc<dyn crate::service::RequestIdProvider> =
557+
Arc::new(AtomicU32RequestIdProvider::default());
558+
let (peer, _rx) = Peer::<RoleServer>::new(id_provider, None);
559+
let ctx = crate::handler::server::tool::ToolCallContext::new(
560+
&service,
561+
CallToolRequestParams {
562+
meta: None,
563+
name: Cow::Borrowed("test_tool"),
564+
arguments: None,
565+
task: None,
566+
},
567+
RequestContext::new(NumberOrString::Number(1), peer),
568+
);
569+
570+
let err = router
571+
.call(ctx)
572+
.await
573+
.expect_err("disabled tool should reject");
574+
assert_eq!(err.code, ErrorCode::INVALID_PARAMS);
575+
assert_eq!(err.message, "tool not found");
576+
}
577+
}

crates/rmcp/tests/test_tool_routers.rs

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ fn test_disable_route() {
100100
assert!(router.has_route("async_function"));
101101
assert!(router.get("async_function").is_some());
102102

103-
router.disable_route("async_function");
103+
assert!(router.disable_route("async_function"));
104104

105105
assert_eq!(router.list_all().len(), 3);
106106
assert!(!router.has_route("async_function"));
@@ -116,10 +116,10 @@ fn test_disable_route() {
116116
#[test]
117117
fn test_enable_route() {
118118
let mut router = build_router();
119-
router.disable_route("async_function");
119+
assert!(router.disable_route("async_function"));
120120
assert!(!router.has_route("async_function"));
121121

122-
router.enable_route("async_function");
122+
assert!(router.enable_route("async_function"));
123123
assert!(router.has_route("async_function"));
124124
assert!(router.get("async_function").is_some());
125125
assert!(!router.is_disabled("async_function"));
@@ -143,7 +143,7 @@ fn test_with_disabled_builder() {
143143
fn test_disabled_tools_survive_merge() {
144144
let mut router_a = ToolRouter::<TestHandler<()>>::new()
145145
.with_route((async_function_tool_attr(), async_function));
146-
router_a.disable_route("async_function");
146+
assert!(router_a.disable_route("async_function"));
147147

148148
let router_b = ToolRouter::<TestHandler<()>>::new()
149149
.with_route((async_function2_tool_attr(), async_function2));
@@ -158,22 +158,42 @@ fn test_disabled_tools_survive_merge() {
158158
#[test]
159159
fn test_disable_nonexistent_tool() {
160160
let mut router = build_router();
161-
// should not panic
162-
router.disable_route("does_not_exist");
161+
// should not panic; returns true because the name is newly added to disabled set
162+
assert!(router.disable_route("does_not_exist"));
163163
assert_eq!(router.list_all().len(), 4);
164164
// is_disabled returns false for tools not in the map
165165
assert!(!router.is_disabled("does_not_exist"));
166166
}
167167

168168
#[test]
169-
fn test_remove_route_clears_disabled_state() {
169+
fn test_remove_route_preserves_disabled_state() {
170170
let mut router = build_router();
171-
router.disable_route("async_function");
171+
assert!(router.disable_route("async_function"));
172172
assert!(router.is_disabled("async_function"));
173173

174174
router.remove_route("async_function");
175+
assert!(!router.has_route("async_function"));
176+
// Disabled marker is preserved — is_disabled returns false (no route in map)
177+
// but re-adding will inherit the disabled state (tested separately)
175178
assert!(!router.is_disabled("async_function"));
179+
}
180+
181+
#[test]
182+
fn test_remove_route_then_readd_stays_disabled() {
183+
let mut router = build_router();
184+
assert!(router.disable_route("async_function"));
185+
186+
router.remove_route("async_function");
187+
assert!(!router.has_route("async_function"));
188+
189+
// Re-add the route — it should inherit the disabled state
190+
let other = ToolRouter::<TestHandler<()>>::new()
191+
.with_route((async_function_tool_attr(), async_function));
192+
router.merge(other);
193+
176194
assert!(!router.has_route("async_function"));
195+
assert!(router.is_disabled("async_function"));
196+
assert!(router.get("async_function").is_none());
177197
}
178198

179199
#[test]
@@ -211,6 +231,42 @@ fn test_disabled_tool_invisible_across_all_queries() {
211231
assert!(router.get("async_function").is_none());
212232
// Not routable
213233
assert!(!router.has_route("async_function"));
214-
// But still known as disabled
234+
// But known as disabled
215235
assert!(router.is_disabled("async_function"));
216236
}
237+
238+
#[test]
239+
fn test_disable_route_then_add_route_blocks_tool() {
240+
// Full pre-disable lifecycle via runtime mutation (not builder)
241+
let mut router = ToolRouter::<TestHandler<()>>::new();
242+
router.disable_route("async_function");
243+
244+
// Add route after disabling — tool should be blocked
245+
let other = ToolRouter::<TestHandler<()>>::new()
246+
.with_route((async_function_tool_attr(), async_function));
247+
router.merge(other);
248+
249+
assert!(router.is_disabled("async_function"));
250+
assert!(!router.has_route("async_function"));
251+
assert!(router.get("async_function").is_none());
252+
assert_eq!(router.list_all().len(), 0);
253+
}
254+
255+
#[test]
256+
fn test_disable_enable_return_false_cases() {
257+
let mut router = build_router();
258+
259+
// Repeated disable returns false
260+
assert!(router.disable_route("async_function"));
261+
assert!(!router.disable_route("async_function"));
262+
263+
// Enable returns true, then false on repeat
264+
assert!(router.enable_route("async_function"));
265+
assert!(!router.enable_route("async_function"));
266+
267+
// Enable on name never disabled returns false
268+
assert!(!router.enable_route("async_function2"));
269+
270+
// Enable on unknown name returns false
271+
assert!(!router.enable_route("unknown"));
272+
}

0 commit comments

Comments
 (0)