Skip to content

Commit 8750031

Browse files
committed
feat(router): support runtime disabling of tools
Add methods to disable/enable tools at runtime. Disabled tools are hidden from listing, lookup, and execution, including in composed routers. Closes #477
1 parent 6603c1f commit 8750031

3 files changed

Lines changed: 194 additions & 5 deletions

File tree

crates/rmcp/src/handler/server/router.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ where
8484
match request {
8585
ClientRequest::CallToolRequest(request) => {
8686
if self.tool_router.has_route(request.params.name.as_ref())
87+
|| self.tool_router.is_disabled(request.params.name.as_ref())
8788
|| !self.tool_router.transparent_when_not_found
8889
{
8990
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,16 @@ pub struct ToolRouter<S> {
305305
pub map: std::collections::HashMap<Cow<'static, str>, ToolRoute<S>>,
306306

307307
pub transparent_when_not_found: bool,
308+
309+
disabled: std::collections::HashSet<Cow<'static, str>>,
308310
}
309311

310312
impl<S> Default for ToolRouter<S> {
311313
fn default() -> Self {
312314
Self {
313315
map: std::collections::HashMap::new(),
314316
transparent_when_not_found: false,
317+
disabled: std::collections::HashSet::new(),
315318
}
316319
}
317320
}
@@ -320,6 +323,7 @@ impl<S> Clone for ToolRouter<S> {
320323
Self {
321324
map: self.map.clone(),
322325
transparent_when_not_found: self.transparent_when_not_found,
326+
disabled: self.disabled.clone(),
323327
}
324328
}
325329
}
@@ -329,7 +333,11 @@ impl<S> IntoIterator for ToolRouter<S> {
329333
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, ToolRoute<S>>;
330334

331335
fn into_iter(self) -> Self::IntoIter {
332-
self.map.into_values()
336+
let mut map = self.map;
337+
for name in &self.disabled {
338+
map.remove(name);
339+
}
340+
map.into_values()
333341
}
334342
}
335343

@@ -341,6 +349,7 @@ where
341349
Self {
342350
map: std::collections::HashMap::new(),
343351
transparent_when_not_found: false,
352+
disabled: std::collections::HashSet::new(),
344353
}
345354
}
346355
pub fn with_route<R, A>(mut self, route: R) -> Self
@@ -394,24 +403,64 @@ where
394403
}
395404

396405
pub fn merge(&mut self, other: ToolRouter<S>) {
406+
self.disabled.extend(other.disabled);
397407
for item in other.map.into_values() {
398408
self.add_route(item);
399409
}
400410
}
401411

402412
pub fn remove_route(&mut self, name: &str) {
403413
self.map.remove(name);
414+
self.disabled.remove(name);
404415
}
416+
405417
pub fn has_route(&self, name: &str) -> bool {
406-
self.map.contains_key(name)
418+
self.map.contains_key(name) && !self.disabled.contains(name)
419+
}
420+
421+
/// Disable a tool by name so it is hidden from `list_all`, `get`, and
422+
/// rejected by `call`. The tool remains in the router and can be
423+
/// re-enabled later with [`enable_route`](Self::enable_route).
424+
///
425+
/// The name is recorded even if no matching route exists yet, so routes
426+
/// added later (via [`add_route`](Self::add_route) or
427+
/// [`merge`](Self::merge)) will inherit the disabled state.
428+
pub fn disable_route(&mut self, name: &str) {
429+
self.disabled.insert(Cow::Owned(name.to_owned()));
430+
}
431+
432+
/// Re-enable a previously disabled tool.
433+
pub fn enable_route(&mut self, name: &str) {
434+
self.disabled.remove(name);
435+
}
436+
437+
/// Returns `true` if the tool exists in the router but is currently
438+
/// disabled.
439+
pub fn is_disabled(&self, name: &str) -> bool {
440+
self.map.contains_key(name) && self.disabled.contains(name)
441+
}
442+
443+
/// Builder-style variant of [`disable_route`](Self::disable_route).
444+
///
445+
/// The name is recorded even if no matching route has been added yet,
446+
/// so it can be called before [`with_route`](Self::with_route) in a
447+
/// builder chain.
448+
pub fn with_disabled(mut self, name: impl Into<Cow<'static, str>>) -> Self {
449+
self.disabled.insert(name.into());
450+
self
407451
}
452+
408453
pub async fn call(
409454
&self,
410455
context: ToolCallContext<'_, S>,
411456
) -> Result<CallToolResult, crate::ErrorData> {
457+
let name = context.name();
458+
if self.disabled.contains(name) {
459+
return Err(crate::ErrorData::invalid_params("tool not found", None));
460+
}
412461
let item = self
413462
.map
414-
.get(context.name())
463+
.get(name)
415464
.ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?;
416465

417466
let result = (item.call)(context).await?;
@@ -420,15 +469,24 @@ where
420469
}
421470

422471
pub fn list_all(&self) -> Vec<crate::model::Tool> {
423-
let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
472+
let mut tools: Vec<_> = self
473+
.map
474+
.values()
475+
.filter(|item| !self.disabled.contains(&item.attr.name))
476+
.map(|item| item.attr.clone())
477+
.collect();
424478
tools.sort_by(|a, b| a.name.cmp(&b.name));
425479
tools
426480
}
427481

428482
/// Get a tool definition by name.
429483
///
430-
/// Returns the tool if found, or `None` if no tool with the given name exists.
484+
/// Returns the tool if found and enabled, or `None` if the tool does not
485+
/// exist or is disabled.
431486
pub fn get(&self, name: &str) -> Option<&crate::model::Tool> {
487+
if self.disabled.contains(name) {
488+
return None;
489+
}
432490
self.map.get(name).map(|r| &r.attr)
433491
}
434492
}

crates/rmcp/tests/test_tool_routers.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,133 @@ fn test_tool_router_list_all_is_sorted() {
8484
"list_all() should return tools sorted alphabetically by name"
8585
);
8686
}
87+
88+
fn build_router() -> ToolRouter<TestHandler<()>> {
89+
ToolRouter::<TestHandler<()>>::new()
90+
.with_route((async_function_tool_attr(), async_function))
91+
.with_route((async_function2_tool_attr(), async_function2))
92+
+ TestHandler::<()>::test_router_1()
93+
+ TestHandler::<()>::test_router_2()
94+
}
95+
96+
#[test]
97+
fn test_disable_route() {
98+
let mut router = build_router();
99+
assert_eq!(router.list_all().len(), 4);
100+
assert!(router.has_route("async_function"));
101+
assert!(router.get("async_function").is_some());
102+
103+
router.disable_route("async_function");
104+
105+
assert_eq!(router.list_all().len(), 3);
106+
assert!(!router.has_route("async_function"));
107+
assert!(router.get("async_function").is_none());
108+
assert!(router.is_disabled("async_function"));
109+
110+
// other tools unaffected
111+
assert!(router.has_route("async_function2"));
112+
assert!(router.get("async_function2").is_some());
113+
assert!(!router.is_disabled("async_function2"));
114+
}
115+
116+
#[test]
117+
fn test_enable_route() {
118+
let mut router = build_router();
119+
router.disable_route("async_function");
120+
assert!(!router.has_route("async_function"));
121+
122+
router.enable_route("async_function");
123+
assert!(router.has_route("async_function"));
124+
assert!(router.get("async_function").is_some());
125+
assert!(!router.is_disabled("async_function"));
126+
assert_eq!(router.list_all().len(), 4);
127+
}
128+
129+
#[test]
130+
fn test_with_disabled_builder() {
131+
let router = build_router()
132+
.with_disabled("async_function")
133+
.with_disabled("sync_method");
134+
135+
assert_eq!(router.list_all().len(), 2);
136+
assert!(!router.has_route("async_function"));
137+
assert!(!router.has_route("sync_method"));
138+
assert!(router.has_route("async_function2"));
139+
assert!(router.has_route("async_method"));
140+
}
141+
142+
#[test]
143+
fn test_disabled_tools_survive_merge() {
144+
let mut router_a = ToolRouter::<TestHandler<()>>::new()
145+
.with_route((async_function_tool_attr(), async_function));
146+
router_a.disable_route("async_function");
147+
148+
let router_b = ToolRouter::<TestHandler<()>>::new()
149+
.with_route((async_function2_tool_attr(), async_function2));
150+
151+
router_a.merge(router_b);
152+
153+
assert_eq!(router_a.list_all().len(), 1);
154+
assert!(router_a.is_disabled("async_function"));
155+
assert!(router_a.has_route("async_function2"));
156+
}
157+
158+
#[test]
159+
fn test_disable_nonexistent_tool() {
160+
let mut router = build_router();
161+
// should not panic
162+
router.disable_route("does_not_exist");
163+
assert_eq!(router.list_all().len(), 4);
164+
// is_disabled returns false for tools not in the map
165+
assert!(!router.is_disabled("does_not_exist"));
166+
}
167+
168+
#[test]
169+
fn test_remove_route_clears_disabled_state() {
170+
let mut router = build_router();
171+
router.disable_route("async_function");
172+
assert!(router.is_disabled("async_function"));
173+
174+
router.remove_route("async_function");
175+
assert!(!router.is_disabled("async_function"));
176+
assert!(!router.has_route("async_function"));
177+
}
178+
179+
#[test]
180+
fn test_into_iter_skips_disabled() {
181+
let router = build_router().with_disabled("async_function");
182+
let names: Vec<_> = router
183+
.into_iter()
184+
.map(|r| r.attr.name.to_string())
185+
.collect();
186+
assert_eq!(names.len(), 3);
187+
assert!(!names.contains(&"async_function".to_string()));
188+
}
189+
190+
#[test]
191+
fn test_pre_disable_before_add_route() {
192+
// Disabling a name before adding a route with that name should
193+
// result in the route being disabled once added.
194+
let router = ToolRouter::<TestHandler<()>>::new()
195+
.with_disabled("async_function")
196+
.with_route((async_function_tool_attr(), async_function));
197+
198+
assert_eq!(router.list_all().len(), 0);
199+
assert!(router.is_disabled("async_function"));
200+
assert!(!router.has_route("async_function"));
201+
}
202+
203+
#[test]
204+
fn test_disabled_tool_invisible_across_all_queries() {
205+
let router = build_router().with_disabled("async_function");
206+
207+
// Not listed
208+
let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect();
209+
assert!(!names.contains(&"async_function".into()));
210+
// Not retrievable
211+
assert!(router.get("async_function").is_none());
212+
// Not routable
213+
assert!(!router.has_route("async_function"));
214+
// But still known as disabled
215+
assert!(router.is_disabled("async_function"));
216+
}

0 commit comments

Comments
 (0)