Skip to content

Commit 254f04a

Browse files
authored
fix: strip and validate tool outputSchema and inputSchema (#860)
* fix: remove unnecessary fields from tools' outputSchema * fix: validate input schema root type per MCP spec
1 parent 8f558d8 commit 254f04a

9 files changed

Lines changed: 158 additions & 65 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async fn main() -> anyhow::Result<()> {
169169
}
170170
```
171171

172-
The generated tool `inputSchema` is derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used.
172+
The generated tool `inputSchema` and `outputSchema` are derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used.
173173

174174
When you need custom server metadata or multiple capabilities (tools + prompts), use explicit `#[tool_handler]`:
175175

crates/rmcp-macros/src/tool.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
232232
// if found, use the Parameters schema
233233
syn::parse2::<Expr>(quote! {
234234
rmcp::handler::server::common::schema_for_input::<#params_ty>()
235+
.unwrap_or_else(|e| {
236+
panic!(
237+
"Invalid input schema for `{}`: {}",
238+
std::any::type_name::<#params_ty>(),
239+
e
240+
)
241+
})
235242
})?
236243
} else {
237244
// if not found, use a default empty JSON schema object

crates/rmcp/src/handler/server/common.rs

Lines changed: 89 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
//! Common utilities shared between tool and prompt handlers
22
3-
use std::{any::TypeId, collections::HashMap, sync::Arc};
3+
use std::{
4+
any::TypeId,
5+
collections::HashMap,
6+
sync::{Arc, LazyLock},
7+
};
48

59
use schemars::JsonSchema;
610

@@ -30,12 +34,10 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
3034
let generator = settings.into_generator();
3135
let schema = generator.into_root_schema_for::<T>();
3236
let object = serde_json::to_value(schema).expect("failed to serialize schema");
33-
let object = match object {
34-
serde_json::Value::Object(object) => object,
35-
_ => panic!(
36-
"Schema serialization produced non-object value: expected JSON object but got {:?}",
37-
object
38-
),
37+
let serde_json::Value::Object(object) = object else {
38+
panic!(
39+
"Schema serialization produced non-object value: expected JSON object but got {object:?}"
40+
);
3941
};
4042
let schema = Arc::new(object);
4143
cache
@@ -48,51 +50,63 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
4850
})
4951
}
5052

51-
/// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object)
52-
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
53+
/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level
54+
/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM).
55+
fn validate_and_strip(raw: &Arc<JsonObject>, purpose: &str) -> Result<Arc<JsonObject>, String> {
56+
match raw.get("type") {
57+
Some(serde_json::Value::String(t)) if t == "object" => {
58+
let mut object = raw.as_ref().clone();
59+
object.remove("title");
60+
object.remove("description");
61+
Ok(Arc::new(object))
62+
}
63+
Some(serde_json::Value::String(t)) => Err(format!(
64+
"MCP specification requires tool {purpose} to have root type 'object', but found '{t}'."
65+
)),
66+
None => Err(format!(
67+
"Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'."
68+
)),
69+
Some(other) => Err(format!(
70+
"Schema 'type' field has unexpected format: {other:?}. Expected \"object\"."
71+
)),
72+
}
73+
}
74+
75+
/// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object";
76+
/// top-level "title" and "description" are removed).
77+
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
5378
thread_local! {
54-
static CACHE_FOR_INPUT: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
79+
static CACHE_FOR_INPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
5580
};
5681
CACHE_FOR_INPUT.with(|cache| {
57-
if let Some(schema) = cache
82+
if let Some(result) = cache
5883
.read()
5984
.expect("input schema cache lock poisoned")
6085
.get(&TypeId::of::<T>())
6186
{
62-
schema.clone()
63-
} else {
64-
let mut schema = schema_for_type::<T>().as_ref().clone();
65-
66-
// Remove unnecessary top-level fields
67-
schema.remove("title");
68-
schema.remove("description");
69-
70-
let schema = Arc::new(schema);
71-
cache
72-
.write()
73-
.expect("input schema cache lock poisoned")
74-
.insert(TypeId::of::<T>(), schema.clone());
75-
76-
schema
87+
return result.clone();
7788
}
89+
let result = validate_and_strip(&schema_for_type::<T>(), "inputSchema");
90+
cache
91+
.write()
92+
.expect("input schema cache lock poisoned")
93+
.insert(TypeId::of::<T>(), result.clone());
94+
result
7895
})
7996
}
8097

81-
// TODO: should be updated according to the new specifications
8298
/// Schema used when input is empty.
8399
pub fn schema_for_empty_input() -> Arc<JsonObject> {
84-
std::sync::Arc::new(
85-
serde_json::json!({
86-
"type": "object",
87-
"properties": {}
88-
})
89-
.as_object()
90-
.unwrap()
91-
.clone(),
92-
)
100+
static EMPTY: LazyLock<Arc<JsonObject>> = LazyLock::new(|| {
101+
let mut object = JsonObject::new();
102+
object.insert("type".into(), serde_json::json!("object"));
103+
object.insert("properties".into(), serde_json::json!({}));
104+
Arc::new(object)
105+
});
106+
EMPTY.clone()
93107
}
94108

95-
/// Generate and validate a JSON schema for outputSchema (must have root type "object").
109+
/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed)
96110
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
97111
thread_local! {
98112
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
@@ -108,22 +122,8 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
108122
return result.clone();
109123
}
110124

111-
// Generate and validate schema
112-
let schema = schema_for_type::<T>();
113-
let result = match schema.get("type") {
114-
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
115-
Some(serde_json::Value::String(t)) => Err(format!(
116-
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
117-
t
118-
)),
119-
None => Err(
120-
"Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
121-
),
122-
Some(other) => Err(format!(
123-
"Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
124-
other
125-
)),
126-
};
125+
// Generate, validate, and strip unnecessary top-level fields
126+
let result = validate_and_strip(&schema_for_type::<T>(), "outputSchema");
127127

128128
// Cache the result (both success and error cases)
129129
cache
@@ -316,4 +316,40 @@ mod tests {
316316
let result = schema_for_output::<TestObject>();
317317
assert!(result.is_ok(),);
318318
}
319+
320+
#[test]
321+
fn test_schema_for_output_strips_top_level_title() {
322+
let schema = schema_for_output::<TestObject>().unwrap();
323+
assert!(!schema.contains_key("title"));
324+
}
325+
326+
#[test]
327+
fn test_schema_for_output_strips_top_level_description() {
328+
let schema = schema_for_output::<TestObject>().unwrap();
329+
assert!(!schema.contains_key("description"));
330+
}
331+
332+
#[test]
333+
fn test_schema_for_input_rejects_primitive() {
334+
let result = schema_for_input::<i32>();
335+
assert!(result.is_err());
336+
}
337+
338+
#[test]
339+
fn test_schema_for_input_accepts_object() {
340+
let result = schema_for_input::<TestObject>();
341+
assert!(result.is_ok());
342+
}
343+
344+
#[test]
345+
fn test_schema_for_input_strips_top_level_title() {
346+
let schema = schema_for_input::<TestObject>().unwrap();
347+
assert!(!schema.contains_key("title"));
348+
}
349+
350+
#[test]
351+
fn test_schema_for_input_strips_top_level_description() {
352+
let schema = schema_for_input::<TestObject>().unwrap();
353+
assert!(!schema.contains_key("description"));
354+
}
319355
}

crates/rmcp/src/handler/server/router/tool.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,9 @@ where
250250
attr: Tool::new(
251251
name.into(),
252252
"",
253-
schema_for_input::<crate::model::JsonObject>(),
253+
schema_for_input::<crate::model::JsonObject>().unwrap_or_else(|e| {
254+
panic!("Invalid input schema for JsonObject: {e}");
255+
}),
254256
),
255257
call: self,
256258
_marker: std::marker::PhantomData,
@@ -287,7 +289,12 @@ where
287289
self
288290
}
289291
pub fn parameters<T: JsonSchema + 'static>(mut self) -> Self {
290-
self.attr.input_schema = schema_for_input::<T>();
292+
self.attr.input_schema = schema_for_input::<T>().unwrap_or_else(|e| {
293+
panic!(
294+
"Invalid input schema for `{}`: {e}",
295+
std::any::type_name::<T>()
296+
)
297+
});
291298
self
292299
}
293300
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {

crates/rmcp/src/handler/server/router/tool/tool_traits.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ pub trait ToolBase {
4949
/// If the tool does not have any parameters, you should override this methods to return [`None`],
5050
/// and when invoked, the parameter will get default values.
5151
fn input_schema() -> Option<Arc<JsonObject>> {
52-
Some(schema_for_input::<Parameters<Self::Parameter>>())
52+
Some(
53+
schema_for_input::<Parameters<Self::Parameter>>().unwrap_or_else(|e| {
54+
panic!(
55+
"Invalid input schema for ToolBase::Parameter type `{0}`: {e}",
56+
std::any::type_name::<Self::Parameter>(),
57+
);
58+
}),
59+
)
5360
}
5461

5562
/// Json schema for tool output.

crates/rmcp/src/model/tool.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,8 @@ impl Tool {
330330
/// Set the input schema using a type that implements JsonSchema
331331
#[cfg(feature = "server")]
332332
pub fn with_input_schema<T: JsonSchema + 'static>(mut self) -> Self {
333-
self.input_schema = crate::handler::server::tool::schema_for_input::<T>();
333+
self.input_schema = crate::handler::server::tool::schema_for_input::<T>()
334+
.unwrap_or_else(|e| panic!("Invalid input schema for tool '{}': {}", self.name, e));
334335
self
335336
}
336337

crates/rmcp/tests/test_list_tools_result.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![cfg(all(feature = "server", feature = "macros", not(feature = "local")))]
22

33
use rmcp::{
4+
Json,
45
handler::server::wrapper::Parameters,
56
model::{ListToolsResult, NumberOrString, ServerJsonRpcMessage, ServerResult},
67
};
@@ -14,10 +15,17 @@ struct AddRequest {
1415
b: f64,
1516
}
1617

18+
/// Result of adding two numbers.
19+
#[derive(Debug, serde::Serialize, schemars::JsonSchema)]
20+
struct AddResult {
21+
/// The sum of the two numbers.
22+
sum: f64,
23+
}
24+
1725
/// Add two numbers.
1826
#[rmcp::tool]
19-
fn add(Parameters(AddRequest { a, b }): Parameters<AddRequest>) -> String {
20-
(a + b).to_string()
27+
fn add(Parameters(AddRequest { a, b }): Parameters<AddRequest>) -> Json<AddResult> {
28+
Json(AddResult { sum: a + b })
2129
}
2230

2331
#[test]
@@ -27,7 +35,7 @@ fn list_tools_result_matches_expected_json() {
2735
let expected: serde_json::Value =
2836
serde_json::from_slice(&expected_json).expect("invalid expected JSON fixture");
2937

30-
assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })), "3");
38+
assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })).0.sum, 3.0);
3139

3240
let result = ListToolsResult::with_all_items(vec![add_tool_attr()]);
3341
let response = ServerJsonRpcMessage::response(

crates/rmcp/tests/test_list_tools_result/list_tools_result.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323
"a",
2424
"b"
2525
]
26+
},
27+
"outputSchema": {
28+
"$schema": "https://json-schema.org/draft/2020-12/schema",
29+
"type": "object",
30+
"properties": {
31+
"sum": {
32+
"description": "The sum of the two numbers.",
33+
"format": "double",
34+
"type": "number"
35+
}
36+
},
37+
"required": [
38+
"sum"
39+
]
2640
}
2741
}
2842
]

crates/rmcp/tests/test_structured_output.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ pub struct UserInfo {
2828
pub age: u32,
2929
}
3030

31+
#[derive(Serialize, Deserialize, JsonSchema)]
32+
pub struct GreetingRequest {
33+
pub name: String,
34+
}
35+
36+
#[derive(Serialize, Deserialize, JsonSchema)]
37+
pub struct GetUserRequest {
38+
pub user_id: String,
39+
}
40+
3141
#[tool_handler(router = self.tool_router)]
3242
impl ServerHandler for TestServer {}
3343

@@ -64,14 +74,17 @@ impl TestServer {
6474

6575
/// Tool that returns regular string output
6676
#[tool(name = "get-greeting", description = "Get a greeting")]
67-
pub async fn get_greeting(&self, name: Parameters<String>) -> String {
68-
format!("Hello, {}!", name.0)
77+
pub async fn get_greeting(&self, params: Parameters<GreetingRequest>) -> String {
78+
format!("Hello, {}!", params.0.name)
6979
}
7080

7181
/// Tool that returns structured user info
7282
#[tool(name = "get-user", description = "Get user info")]
73-
pub async fn get_user(&self, user_id: Parameters<String>) -> Result<Json<UserInfo>, String> {
74-
if user_id.0 == "123" {
83+
pub async fn get_user(
84+
&self,
85+
params: Parameters<GetUserRequest>,
86+
) -> Result<Json<UserInfo>, String> {
87+
if params.0.user_id == "123" {
7588
Ok(Json(UserInfo {
7689
name: "Alice".to_string(),
7790
age: 30,

0 commit comments

Comments
 (0)