Skip to content

Commit c0cfbaa

Browse files
committed
fix: remove unnecessary fields from tools' inputSchema
1 parent 3529c36 commit c0cfbaa

10 files changed

Lines changed: 114 additions & 10 deletions

File tree

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ async fn main() -> anyhow::Result<()> {
169169
}
170170
```
171171

172+
The generated tool `inputSchema` is derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used.
173+
172174
When you need custom server metadata or multiple capabilities (tools + prompts), use explicit `#[tool_handler]`:
173175

174176
```rust,ignore

crates/rmcp-macros/src/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
231231
if let Some(params_ty) = params_ty {
232232
// if found, use the Parameters schema
233233
syn::parse2::<Expr>(quote! {
234-
rmcp::handler::server::common::schema_for_type::<#params_ty>()
234+
rmcp::handler::server::common::schema_for_input::<#params_ty>()
235235
})?
236236
} else {
237237
// if not found, use a default empty JSON schema object

crates/rmcp/src/handler/server/common.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,36 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
4848
})
4949
}
5050

51+
/// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object)
52+
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
53+
thread_local! {
54+
static CACHE_FOR_INPUT: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
55+
};
56+
CACHE_FOR_INPUT.with(|cache| {
57+
if let Some(schema) = cache
58+
.read()
59+
.expect("input schema cache lock poisoned")
60+
.get(&TypeId::of::<T>())
61+
{
62+
schema.clone()
63+
} else {
64+
let mut schema = schema_for_type::<T>().as_ref().clone();
65+
66+
// Remove unnecessary top-level fields
67+
schema.remove("title");
68+
schema.remove("description");
69+
70+
let schema = Arc::new(schema);
71+
cache
72+
.write()
73+
.expect("input schema cache lock poisoned")
74+
.insert(TypeId::of::<T>(), schema.clone());
75+
76+
schema
77+
}
78+
})
79+
}
80+
5181
// TODO: should be updated according to the new specifications
5282
/// Schema used when input is empty.
5383
pub fn schema_for_empty_input() -> Arc<JsonObject> {

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ pub use tool_traits::{AsyncTool, SyncTool, ToolBase};
133133

134134
use crate::{
135135
handler::server::{
136-
tool::{CallToolHandler, DynCallToolHandler, ToolCallContext, schema_for_type},
136+
common::schema_for_input,
137+
tool::{CallToolHandler, DynCallToolHandler, ToolCallContext},
137138
tool_name_validation::validate_and_warn_tool_name,
138139
},
139140
model::{CallToolResult, Tool, ToolAnnotations},
@@ -249,7 +250,7 @@ where
249250
attr: Tool::new(
250251
name.into(),
251252
"",
252-
schema_for_type::<crate::model::JsonObject>(),
253+
schema_for_input::<crate::model::JsonObject>(),
253254
),
254255
call: self,
255256
_marker: std::marker::PhantomData,
@@ -286,7 +287,7 @@ where
286287
self
287288
}
288289
pub fn parameters<T: JsonSchema + 'static>(mut self) -> Self {
289-
self.attr.input_schema = schema_for_type::<T>();
290+
self.attr.input_schema = schema_for_input::<T>();
290291
self
291292
}
292293
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {

crates/rmcp/src/handler/server/router/tool/tool_traits.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use serde::{Deserialize, Serialize};
55
use crate::{
66
ErrorData,
77
handler::server::{
8-
common::schema_for_empty_input,
9-
tool::{schema_for_output, schema_for_type},
8+
common::{schema_for_empty_input, schema_for_input},
9+
tool::schema_for_output,
1010
wrapper::{Json, Parameters},
1111
},
1212
model::{Icon, JsonObject, Meta, ToolAnnotations, ToolExecution},
@@ -49,7 +49,7 @@ pub trait ToolBase {
4949
/// If the tool does not have any parameters, you should override this methods to return [`None`],
5050
/// and when invoked, the parameter will get default values.
5151
fn input_schema() -> Option<Arc<JsonObject>> {
52-
Some(schema_for_type::<Parameters<Self::Parameter>>())
52+
Some(schema_for_input::<Parameters<Self::Parameter>>())
5353
}
5454

5555
/// Json schema for tool output.

crates/rmcp/src/handler/server/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use serde::de::DeserializeOwned;
1010

1111
use super::common::{AsRequestContext, FromContextPart};
1212
pub use super::{
13-
common::{Extension, RequestId, schema_for_output, schema_for_type},
13+
common::{Extension, RequestId, schema_for_input, schema_for_output, schema_for_type},
1414
router::tool::{ToolRoute, ToolRouter},
1515
};
1616
use crate::{

crates/rmcp/src/model/tool.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ impl Tool {
330330
/// Set the input schema using a type that implements JsonSchema
331331
#[cfg(feature = "server")]
332332
pub fn with_input_schema<T: JsonSchema + 'static>(mut self) -> Self {
333-
self.input_schema = crate::handler::server::tool::schema_for_type::<T>();
333+
self.input_schema = crate::handler::server::tool::schema_for_input::<T>();
334334
self
335335
}
336336

crates/rmcp/tests/test_complex_schema.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ fn expected_schema() -> serde_json::Value {
9191
"required": [
9292
"messages"
9393
],
94-
"title": "ChatRequest",
9594
"type": "object"
9695
})
9796
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#![cfg(all(feature = "server", feature = "macros", not(feature = "local")))]
2+
3+
use rmcp::{
4+
handler::server::wrapper::Parameters,
5+
model::{ListToolsResult, NumberOrString, ServerJsonRpcMessage, ServerResult},
6+
};
7+
8+
/// Parameters for adding two numbers.
9+
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
10+
struct AddRequest {
11+
/// The left-hand number.
12+
a: f64,
13+
/// The right-hand number.
14+
b: f64,
15+
}
16+
17+
/// Add two numbers.
18+
#[rmcp::tool]
19+
fn add(Parameters(AddRequest { a, b }): Parameters<AddRequest>) -> String {
20+
(a + b).to_string()
21+
}
22+
23+
#[test]
24+
fn list_tools_result_matches_expected_json() {
25+
let expected_json = std::fs::read("tests/test_list_tools_result/list_tools_result.json")
26+
.expect("missing expected list tools result JSON fixture");
27+
let expected: serde_json::Value =
28+
serde_json::from_slice(&expected_json).expect("invalid expected JSON fixture");
29+
30+
assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })), "3");
31+
32+
let result = ListToolsResult::with_all_items(vec![add_tool_attr()]);
33+
let response = ServerJsonRpcMessage::response(
34+
ServerResult::ListToolsResult(result),
35+
NumberOrString::Number(2),
36+
);
37+
38+
let actual = serde_json::to_value(response).expect("failed to serialize list tools response");
39+
assert_eq!(actual, expected);
40+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"result": {
3+
"tools": [
4+
{
5+
"name": "add",
6+
"description": "Add two numbers.",
7+
"inputSchema": {
8+
"$schema": "https://json-schema.org/draft/2020-12/schema",
9+
"type": "object",
10+
"properties": {
11+
"a": {
12+
"description": "The left-hand number.",
13+
"format": "double",
14+
"type": "number"
15+
},
16+
"b": {
17+
"description": "The right-hand number.",
18+
"format": "double",
19+
"type": "number"
20+
}
21+
},
22+
"required": [
23+
"a",
24+
"b"
25+
]
26+
}
27+
}
28+
]
29+
},
30+
"jsonrpc": "2.0",
31+
"id": 2
32+
}

0 commit comments

Comments
 (0)