Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/rmcp/src/handler/server/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ where
match request {
ClientRequest::CallToolRequest(request) => {
if self.tool_router.has_route(request.params.name.as_ref())
|| self.tool_router.is_disabled(request.params.name.as_ref())
|| !self.tool_router.transparent_when_not_found
{
let tool_call_context = crate::handler::server::tool::ToolCallContext::new(
Expand Down
68 changes: 63 additions & 5 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,16 @@ pub struct ToolRouter<S> {
pub map: std::collections::HashMap<Cow<'static, str>, ToolRoute<S>>,

pub transparent_when_not_found: bool,

disabled: std::collections::HashSet<Cow<'static, str>>,
}

impl<S> Default for ToolRouter<S> {
fn default() -> Self {
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
disabled: std::collections::HashSet::new(),
}
}
}
Expand All @@ -320,6 +323,7 @@ impl<S> Clone for ToolRouter<S> {
Self {
map: self.map.clone(),
transparent_when_not_found: self.transparent_when_not_found,
disabled: self.disabled.clone(),
}
}
}
Expand All @@ -329,7 +333,11 @@ impl<S> IntoIterator for ToolRouter<S> {
type IntoIter = std::collections::hash_map::IntoValues<Cow<'static, str>, ToolRoute<S>>;

fn into_iter(self) -> Self::IntoIter {
self.map.into_values()
let mut map = self.map;
for name in &self.disabled {
map.remove(name);
}
map.into_values()
}
}

Expand All @@ -341,6 +349,7 @@ where
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
disabled: std::collections::HashSet::new(),
}
}
pub fn with_route<R, A>(mut self, route: R) -> Self
Expand Down Expand Up @@ -394,24 +403,64 @@ where
}

pub fn merge(&mut self, other: ToolRouter<S>) {
self.disabled.extend(other.disabled);
for item in other.map.into_values() {
self.add_route(item);
}
}

pub fn remove_route(&mut self, name: &str) {
self.map.remove(name);
self.disabled.remove(name);
}

pub fn has_route(&self, name: &str) -> bool {
self.map.contains_key(name)
self.map.contains_key(name) && !self.disabled.contains(name)
}

/// Disable a tool by name so it is hidden from `list_all`, `get`, and
/// rejected by `call`. The tool remains in the router and can be
/// re-enabled later with [`enable_route`](Self::enable_route).
///
/// The name is recorded even if no matching route exists yet, so routes
/// added later (via [`add_route`](Self::add_route) or
/// [`merge`](Self::merge)) will inherit the disabled state.
pub fn disable_route(&mut self, name: &str) {
self.disabled.insert(Cow::Owned(name.to_owned()));
Comment thread
DaleSeo marked this conversation as resolved.
Outdated
}

/// Re-enable a previously disabled tool.
pub fn enable_route(&mut self, name: &str) {
self.disabled.remove(name);
}

/// Returns `true` if the tool exists in the router but is currently
/// disabled.
pub fn is_disabled(&self, name: &str) -> bool {
self.map.contains_key(name) && self.disabled.contains(name)
Comment thread
DaleSeo marked this conversation as resolved.
}

/// Builder-style variant of [`disable_route`](Self::disable_route).
///
/// The name is recorded even if no matching route has been added yet,
/// so it can be called before [`with_route`](Self::with_route) in a
/// builder chain.
pub fn with_disabled(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.disabled.insert(name.into());
self
}

pub async fn call(
&self,
context: ToolCallContext<'_, S>,
) -> Result<CallToolResult, crate::ErrorData> {
let name = context.name();
if self.disabled.contains(name) {
return Err(crate::ErrorData::invalid_params("tool not found", None));
}
let item = self
.map
.get(context.name())
.get(name)
.ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?;

let result = (item.call)(context).await?;
Expand All @@ -420,15 +469,24 @@ where
}

pub fn list_all(&self) -> Vec<crate::model::Tool> {
let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
let mut tools: Vec<_> = self
.map
.values()
.filter(|item| !self.disabled.contains(&item.attr.name))
.map(|item| item.attr.clone())
.collect();
tools.sort_by(|a, b| a.name.cmp(&b.name));
tools
}

/// Get a tool definition by name.
///
/// Returns the tool if found, or `None` if no tool with the given name exists.
/// Returns the tool if found and enabled, or `None` if the tool does not
/// exist or is disabled.
pub fn get(&self, name: &str) -> Option<&crate::model::Tool> {
if self.disabled.contains(name) {
return None;
}
self.map.get(name).map(|r| &r.attr)
}
}
Expand Down
130 changes: 130 additions & 0 deletions crates/rmcp/tests/test_tool_routers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,133 @@ fn test_tool_router_list_all_is_sorted() {
"list_all() should return tools sorted alphabetically by name"
);
}

fn build_router() -> ToolRouter<TestHandler<()>> {
ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function))
.with_route((async_function2_tool_attr(), async_function2))
+ TestHandler::<()>::test_router_1()
+ TestHandler::<()>::test_router_2()
}

#[test]
fn test_disable_route() {
let mut router = build_router();
assert_eq!(router.list_all().len(), 4);
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());

router.disable_route("async_function");

assert_eq!(router.list_all().len(), 3);
assert!(!router.has_route("async_function"));
assert!(router.get("async_function").is_none());
assert!(router.is_disabled("async_function"));

// other tools unaffected
assert!(router.has_route("async_function2"));
assert!(router.get("async_function2").is_some());
assert!(!router.is_disabled("async_function2"));
}

#[test]
fn test_enable_route() {
let mut router = build_router();
router.disable_route("async_function");
assert!(!router.has_route("async_function"));

router.enable_route("async_function");
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());
assert!(!router.is_disabled("async_function"));
assert_eq!(router.list_all().len(), 4);
}

#[test]
fn test_with_disabled_builder() {
let router = build_router()
.with_disabled("async_function")
.with_disabled("sync_method");

assert_eq!(router.list_all().len(), 2);
assert!(!router.has_route("async_function"));
assert!(!router.has_route("sync_method"));
assert!(router.has_route("async_function2"));
assert!(router.has_route("async_method"));
}

#[test]
fn test_disabled_tools_survive_merge() {
let mut router_a = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function));
router_a.disable_route("async_function");

let router_b = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function2_tool_attr(), async_function2));

router_a.merge(router_b);

assert_eq!(router_a.list_all().len(), 1);
assert!(router_a.is_disabled("async_function"));
assert!(router_a.has_route("async_function2"));
}

#[test]
fn test_disable_nonexistent_tool() {
let mut router = build_router();
// should not panic
router.disable_route("does_not_exist");
assert_eq!(router.list_all().len(), 4);
// is_disabled returns false for tools not in the map
assert!(!router.is_disabled("does_not_exist"));
}

#[test]
fn test_remove_route_clears_disabled_state() {
let mut router = build_router();
router.disable_route("async_function");
assert!(router.is_disabled("async_function"));

router.remove_route("async_function");
assert!(!router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
}

#[test]
fn test_into_iter_skips_disabled() {
let router = build_router().with_disabled("async_function");
let names: Vec<_> = router
.into_iter()
.map(|r| r.attr.name.to_string())
.collect();
assert_eq!(names.len(), 3);
assert!(!names.contains(&"async_function".to_string()));
}

#[test]
fn test_pre_disable_before_add_route() {
// Disabling a name before adding a route with that name should
// result in the route being disabled once added.
let router = ToolRouter::<TestHandler<()>>::new()
.with_disabled("async_function")
.with_route((async_function_tool_attr(), async_function));

assert_eq!(router.list_all().len(), 0);
assert!(router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
}

#[test]
fn test_disabled_tool_invisible_across_all_queries() {
let router = build_router().with_disabled("async_function");

// Not listed
let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect();
assert!(!names.contains(&"async_function".into()));
// Not retrievable
assert!(router.get("async_function").is_none());
// Not routable
assert!(!router.has_route("async_function"));
// But still known as disabled
assert!(router.is_disabled("async_function"));
}
Loading