11//! Common utilities shared between tool and prompt handlers
22
3- use std:: { any:: TypeId , collections:: HashMap , sync:: Arc } ;
3+ use std:: {
4+ any:: TypeId ,
5+ collections:: HashMap ,
6+ sync:: { Arc , LazyLock } ,
7+ } ;
48
59use schemars:: JsonSchema ;
610
@@ -30,12 +34,10 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
3034 let generator = settings. into_generator ( ) ;
3135 let schema = generator. into_root_schema_for :: < T > ( ) ;
3236 let object = serde_json:: to_value ( schema) . expect ( "failed to serialize schema" ) ;
33- let object = match object {
34- serde_json:: Value :: Object ( object) => object,
35- _ => panic ! (
36- "Schema serialization produced non-object value: expected JSON object but got {:?}" ,
37- object
38- ) ,
37+ let serde_json:: Value :: Object ( object) = object else {
38+ panic ! (
39+ "Schema serialization produced non-object value: expected JSON object but got {object:?}"
40+ ) ;
3941 } ;
4042 let schema = Arc :: new ( object) ;
4143 cache
@@ -48,51 +50,63 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
4850 } )
4951}
5052
51- /// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object)
52- pub fn schema_for_input < T : JsonSchema + std:: any:: Any > ( ) -> Arc < JsonObject > {
53+ /// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level
54+ /// `title`/`description` (the wrapper type name and doc, which are noise to the LLM).
55+ fn validate_and_strip ( raw : & Arc < JsonObject > , purpose : & str ) -> Result < Arc < JsonObject > , String > {
56+ match raw. get ( "type" ) {
57+ Some ( serde_json:: Value :: String ( t) ) if t == "object" => {
58+ let mut object = raw. as_ref ( ) . clone ( ) ;
59+ object. remove ( "title" ) ;
60+ object. remove ( "description" ) ;
61+ Ok ( Arc :: new ( object) )
62+ }
63+ Some ( serde_json:: Value :: String ( t) ) => Err ( format ! (
64+ "MCP specification requires tool {purpose} to have root type 'object', but found '{t}'."
65+ ) ) ,
66+ None => Err ( format ! (
67+ "Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'."
68+ ) ) ,
69+ Some ( other) => Err ( format ! (
70+ "Schema 'type' field has unexpected format: {other:?}. Expected \" object\" ."
71+ ) ) ,
72+ }
73+ }
74+
75+ /// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object";
76+ /// top-level "title" and "description" are removed).
77+ pub fn schema_for_input < T : JsonSchema + std:: any:: Any > ( ) -> Result < Arc < JsonObject > , String > {
5378 thread_local ! {
54- static CACHE_FOR_INPUT : std:: sync:: RwLock <HashMap <TypeId , Arc <JsonObject >>> = Default :: default ( ) ;
79+ static CACHE_FOR_INPUT : std:: sync:: RwLock <HashMap <TypeId , Result < Arc <JsonObject > , String >>> = Default :: default ( ) ;
5580 } ;
5681 CACHE_FOR_INPUT . with ( |cache| {
57- if let Some ( schema ) = cache
82+ if let Some ( result ) = cache
5883 . read ( )
5984 . expect ( "input schema cache lock poisoned" )
6085 . get ( & TypeId :: of :: < T > ( ) )
6186 {
62- schema. clone ( )
63- } else {
64- let mut schema = schema_for_type :: < T > ( ) . as_ref ( ) . clone ( ) ;
65-
66- // Remove unnecessary top-level fields
67- schema. remove ( "title" ) ;
68- schema. remove ( "description" ) ;
69-
70- let schema = Arc :: new ( schema) ;
71- cache
72- . write ( )
73- . expect ( "input schema cache lock poisoned" )
74- . insert ( TypeId :: of :: < T > ( ) , schema. clone ( ) ) ;
75-
76- schema
87+ return result. clone ( ) ;
7788 }
89+ let result = validate_and_strip ( & schema_for_type :: < T > ( ) , "inputSchema" ) ;
90+ cache
91+ . write ( )
92+ . expect ( "input schema cache lock poisoned" )
93+ . insert ( TypeId :: of :: < T > ( ) , result. clone ( ) ) ;
94+ result
7895 } )
7996}
8097
81- // TODO: should be updated according to the new specifications
8298/// Schema used when input is empty.
8399pub fn schema_for_empty_input ( ) -> Arc < JsonObject > {
84- std:: sync:: Arc :: new (
85- serde_json:: json!( {
86- "type" : "object" ,
87- "properties" : { }
88- } )
89- . as_object ( )
90- . unwrap ( )
91- . clone ( ) ,
92- )
100+ static EMPTY : LazyLock < Arc < JsonObject > > = LazyLock :: new ( || {
101+ let mut object = JsonObject :: new ( ) ;
102+ object. insert ( "type" . into ( ) , serde_json:: json!( "object" ) ) ;
103+ object. insert ( "properties" . into ( ) , serde_json:: json!( { } ) ) ;
104+ Arc :: new ( object)
105+ } ) ;
106+ EMPTY . clone ( )
93107}
94108
95- /// Generate and validate a JSON schema for outputSchema (must have root type "object").
109+ /// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed)
96110pub fn schema_for_output < T : JsonSchema + std:: any:: Any > ( ) -> Result < Arc < JsonObject > , String > {
97111 thread_local ! {
98112 static CACHE_FOR_OUTPUT : std:: sync:: RwLock <HashMap <TypeId , Result <Arc <JsonObject >, String >>> = Default :: default ( ) ;
@@ -108,22 +122,8 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
108122 return result. clone ( ) ;
109123 }
110124
111- // Generate and validate schema
112- let schema = schema_for_type :: < T > ( ) ;
113- let result = match schema. get ( "type" ) {
114- Some ( serde_json:: Value :: String ( t) ) if t == "object" => Ok ( schema. clone ( ) ) ,
115- Some ( serde_json:: Value :: String ( t) ) => Err ( format ! (
116- "MCP specification requires tool outputSchema to have root type 'object', but found '{}'." ,
117- t
118- ) ) ,
119- None => Err (
120- "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'." . to_string ( )
121- ) ,
122- Some ( other) => Err ( format ! (
123- "Schema 'type' field has unexpected format: {:?}. Expected \" object\" ." ,
124- other
125- ) ) ,
126- } ;
125+ // Generate, validate, and strip unnecessary top-level fields
126+ let result = validate_and_strip ( & schema_for_type :: < T > ( ) , "outputSchema" ) ;
127127
128128 // Cache the result (both success and error cases)
129129 cache
@@ -316,4 +316,40 @@ mod tests {
316316 let result = schema_for_output :: < TestObject > ( ) ;
317317 assert ! ( result. is_ok( ) , ) ;
318318 }
319+
320+ #[ test]
321+ fn test_schema_for_output_strips_top_level_title ( ) {
322+ let schema = schema_for_output :: < TestObject > ( ) . unwrap ( ) ;
323+ assert ! ( !schema. contains_key( "title" ) ) ;
324+ }
325+
326+ #[ test]
327+ fn test_schema_for_output_strips_top_level_description ( ) {
328+ let schema = schema_for_output :: < TestObject > ( ) . unwrap ( ) ;
329+ assert ! ( !schema. contains_key( "description" ) ) ;
330+ }
331+
332+ #[ test]
333+ fn test_schema_for_input_rejects_primitive ( ) {
334+ let result = schema_for_input :: < i32 > ( ) ;
335+ assert ! ( result. is_err( ) ) ;
336+ }
337+
338+ #[ test]
339+ fn test_schema_for_input_accepts_object ( ) {
340+ let result = schema_for_input :: < TestObject > ( ) ;
341+ assert ! ( result. is_ok( ) ) ;
342+ }
343+
344+ #[ test]
345+ fn test_schema_for_input_strips_top_level_title ( ) {
346+ let schema = schema_for_input :: < TestObject > ( ) . unwrap ( ) ;
347+ assert ! ( !schema. contains_key( "title" ) ) ;
348+ }
349+
350+ #[ test]
351+ fn test_schema_for_input_strips_top_level_description ( ) {
352+ let schema = schema_for_input :: < TestObject > ( ) . unwrap ( ) ;
353+ assert ! ( !schema. contains_key( "description" ) ) ;
354+ }
319355}
0 commit comments