55import json
66import os
77import threading
8- from typing import Any
8+ from typing import TYPE_CHECKING , Any
9+
10+ if TYPE_CHECKING :
11+ from collections .abc import Awaitable , Callable
912
1013from mcp import StdioServerParameters
1114from mcp .client .session import ClientSession
@@ -27,7 +30,7 @@ def __init__(self, server_config: dict[str, Any], *, function_calling: bool = Tr
2730 self ._system_prompt = None
2831 self .function_calling = function_calling
2932
30- def call_tool (self , tool : str , args : dict [str , Any ]) -> Any :
33+ def call_tool (self , tool : str , args : dict [str , Any ]) -> object :
3134 with self .lock :
3235 return self ._call_tool_sync (tool , args )
3336
@@ -46,48 +49,50 @@ def system_prompt(self) -> str:
4649 def close (self ) -> None :
4750 pass # クライアントの状態管理が不要になったため何もしない
4851
49- def _run_async (self , coro ) :
52+ def _run_async (self , coro : Awaitable [ object ]) -> object :
5053 loop = asyncio .new_event_loop ()
5154 try :
5255 return loop .run_until_complete (coro )
5356 finally :
5457 loop .close ()
5558
56- def _run_with_session (self , coro_fn ) :
57- async def wrapper ():
59+ def _run_with_session (self , coro_fn : Callable [[ ClientSession ], Awaitable [ object ]]) -> object :
60+ async def wrapper () -> object :
5861 cmd = self .server_config ["command" ][0 ]
5962 args = self .server_config ["command" ][1 :]
6063 env = self .server_config .get ("env" , {})
6164 merged_env = dict (os .environ )
6265 merged_env .update (env )
6366 server_params = StdioServerParameters (command = cmd , args = args , env = merged_env )
64- async with stdio_client (server_params ) as (read_stream , write_stream ):
65- async with ClientSession (read_stream , write_stream ) as session :
66- await session .initialize ()
67- # notifications/initialized送信
68- notification = ClientNotification (
69- InitializedNotification (method = "notifications/initialized" ),
70- )
71- await session .send_notification (notification )
72- return await coro_fn (session )
67+ async with (
68+ stdio_client (server_params ) as (read_stream , write_stream ),
69+ ClientSession (read_stream , write_stream ) as session ,
70+ ):
71+ await session .initialize ()
72+ # notifications/initialized送信
73+ notification = ClientNotification (
74+ InitializedNotification (method = "notifications/initialized" ),
75+ )
76+ await session .send_notification (notification )
77+ return await coro_fn (session )
7378
7479 return self ._run_async (wrapper ())
7580
7681 def _git_blob_sha1_from_str (self , s : str , encoding : str = "utf-8" ) -> str :
77- r"""Git blob SHA-1 を文字列から計算する。
82+ r"""Git blob SHA-1 を文字列から計算する.
83+
7884 - s: テキスト文字列(例:"Hello\n")
7985 - encoding: バイト化に使用するエンコーディング.
8086 """
8187 data = s .encode (encoding )
8288 header = f"blob { len (data )} \0 " .encode ()
8389 full = header + data
84- return hashlib .sha1 (full ).hexdigest ()
90+ return hashlib .sha1 (full , usedforsecurity = False ).hexdigest ()
8591
86- def _call_tool_sync (self , tool : str , args : dict [str , Any ]) -> Any :
87- # tool_name = tool.split('_', 1)[1]
92+ def _call_tool_sync (self , tool : str , args : dict [str , Any ]) -> object :
8893 tool_name = tool
8994
90- async def coro_fn (session ) :
95+ async def coro_fn (session : ClientSession ) -> object :
9196 return await session .call_tool (tool_name , args )
9297
9398 result = self ._run_with_session (coro_fn )
@@ -98,7 +103,7 @@ async def coro_fn(session):
98103 try :
99104 obj = json .loads (content .text )
100105 results .append (obj )
101- except Exception :
106+ except ( json . JSONDecodeError , ValueError ) :
102107 results .append (content .text )
103108 elif isinstance (content , EmbeddedResource ):
104109 resource = content .resource
@@ -113,13 +118,13 @@ async def coro_fn(session):
113118
114119 return results [0 ] if len (results ) == 1 else results
115120
116- def _list_tools_sync (self ):
117- async def coro_fn (session ) :
121+ def _list_tools_sync (self ) -> object :
122+ async def coro_fn (session : ClientSession ) -> object :
118123 return await session .list_tools ()
119124
120125 return self ._run_with_session (coro_fn )
121126
122- def _get_tools_sync (self ):
127+ def _get_tools_sync (self ) -> tuple [ str , list [ Tool ]] :
123128 mcp_name = self .server_config .get ("mcp_server_name" , "" )
124129 tools = self .list_tools ().tools
125130 return mcp_name , tools
@@ -149,25 +154,31 @@ def get_function_calling_functions(self) -> list[dict[str, Any]]:
149154 for tool in tools
150155 ]
151156
152- def _get_system_prompt_sync (self ):
157+ def _get_system_prompt_sync (self ) -> str :
153158 mcp_name , tools = self ._get_tools_sync ()
154159 prompt_lines = [f"### { mcp_name } mcp tools" ]
155- for tool in tools :
156- if isinstance (tool , Tool ):
157- tool = {
158- "name" : tool .name ,
159- "description" : tool .description ,
160- "inputSchema" : tool .inputSchema if isinstance (tool .inputSchema , dict ) else {},
161- "required" : tool .inputSchema .get ("required" , []),
160+ for tool_obj in tools :
161+ if isinstance (tool_obj , Tool ):
162+ tool_dict = {
163+ "name" : tool_obj .name ,
164+ "description" : tool_obj .description ,
165+ "inputSchema" : (
166+ tool_obj .inputSchema
167+ if isinstance (tool_obj .inputSchema , dict )
168+ else {}
169+ ),
170+ "required" : tool_obj .inputSchema .get ("required" , []),
162171 }
163- if not isinstance (tool , dict ):
172+ else :
173+ tool_dict = tool_obj
174+ if not isinstance (tool_dict , dict ):
164175 continue
165- tool_name = f"{ mcp_name } _{ tool .get ('name' , '' )} "
166- desc = tool .get ("description" , "" ) or ""
176+ tool_name = f"{ mcp_name } _{ tool_dict .get ('name' , '' )} "
177+ desc = tool_dict .get ("description" , "" ) or ""
167178 desc = desc .replace ("\n " , " " ).replace ("\r " , " " ).strip ()
168- input_schema = tool .get ("inputSchema" , {})
179+ input_schema = tool_dict .get ("inputSchema" , {})
169180 params = input_schema .get ("properties" , {}) if isinstance (input_schema , dict ) else {}
170- required = tool .get ("required" , []) or []
181+ required = tool_dict .get ("required" , []) or []
171182 param_str = (
172183 "{ "
173184 + ", " .join (
0 commit comments