22import asyncio
33import json
44import logging
5+ import os
6+ import re
57from typing import Dict , Optional , Any
68from datetime import datetime
79import subprocess
@@ -28,6 +30,27 @@ def __init__(self, server: MCPServer):
2830 self .status : MCPServerStatus = MCPServerStatus .DISCONNECTED
2931 self .error_message : Optional [str ] = None
3032
33+ def _expand_env_vars (self , value : str ) -> str :
34+ """Expand environment variables in a string."""
35+ # Replace ${VAR} and $VAR patterns with environment variable values
36+ def replace_env_var (match ):
37+ var_name = match .group (1 ) or match .group (2 )
38+ return os .environ .get (var_name , match .group (0 ))
39+
40+ # Match ${VAR} or $VAR patterns
41+ pattern = r'\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)'
42+ return re .sub (pattern , replace_env_var , value )
43+
44+ def _expand_config_env_vars (self , config : Any ) -> Any :
45+ """Recursively expand environment variables in configuration."""
46+ if isinstance (config , str ):
47+ return self ._expand_env_vars (config )
48+ elif isinstance (config , dict ):
49+ return {k : self ._expand_config_env_vars (v ) for k , v in config .items ()}
50+ elif isinstance (config , list ):
51+ return [self ._expand_config_env_vars (item ) for item in config ]
52+ return config
53+
3154 async def connect (self ) -> bool :
3255 """Connect to the MCP server."""
3356 try :
@@ -48,21 +71,36 @@ async def _connect_stdio(self) -> bool:
4871 if not self .server .command :
4972 raise ValueError ("Command is required for stdio transport" )
5073
51- args = [self .server .command ]
74+ # Expand environment variables in command and args
75+ command = self ._expand_env_vars (self .server .command )
76+ args = [command ]
5277 if self .server .args :
53- args .extend (self .server .args )
78+ expanded_args = [self ._expand_env_vars (arg ) for arg in self .server .args ]
79+ args .extend (expanded_args )
5480
5581 # Start the process
56- self .process = await asyncio .create_subprocess_exec (
57- * args ,
58- stdin = asyncio .subprocess .PIPE ,
59- stdout = asyncio .subprocess .PIPE ,
60- stderr = asyncio .subprocess .PIPE
61- )
82+ try :
83+ self .process = await asyncio .create_subprocess_exec (
84+ * args ,
85+ stdin = asyncio .subprocess .PIPE ,
86+ stdout = asyncio .subprocess .PIPE ,
87+ stderr = asyncio .subprocess .PIPE ,
88+ env = {** os .environ } # Pass current environment variables
89+ )
90+ except FileNotFoundError :
91+ raise RuntimeError (f"Command not found: { command } " )
92+ except Exception as e :
93+ raise RuntimeError (f"Failed to start process: { str (e )} " )
6294
6395 self .reader = self .process .stdout
6496 self .writer = self .process .stdin
6597
98+ # Check if process started successfully
99+ await asyncio .sleep (0.1 ) # Give process time to start
100+ if self .process .returncode is not None :
101+ stderr = await self .process .stderr .read ()
102+ raise RuntimeError (f"Process exited immediately with code { self .process .returncode } : { stderr .decode ()} " )
103+
66104 # Start reading messages
67105 self ._read_task = asyncio .create_task (self ._read_messages ())
68106
@@ -96,7 +134,34 @@ async def _connect_http_sse(self) -> bool:
96134 if not self .server .url :
97135 raise ValueError ("URL is required for HTTP+SSE transport" )
98136
99- self .session = aiohttp .ClientSession (headers = self .server .headers or {})
137+ # Expand environment variables in URL
138+ url = self ._expand_env_vars (self .server .url )
139+
140+ # Build headers with authentication
141+ headers = self ._expand_config_env_vars (dict (self .server .headers or {}))
142+ if self .server .auth_type and self .server .auth_config :
143+ # Expand environment variables in auth config
144+ auth_config = self ._expand_config_env_vars (self .server .auth_config )
145+
146+ if self .server .auth_type == "bearer" :
147+ token = auth_config .get ("token" )
148+ if token :
149+ headers ["Authorization" ] = f"Bearer { token } "
150+ elif self .server .auth_type == "api_key" :
151+ key_name = auth_config .get ("key_name" , "X-API-Key" )
152+ api_key = auth_config .get ("api_key" )
153+ if api_key :
154+ headers [key_name ] = api_key
155+ elif self .server .auth_type == "basic" :
156+ username = auth_config .get ("username" )
157+ password = auth_config .get ("password" )
158+ if username and password :
159+ import base64
160+ auth_str = f"{ username } :{ password } "
161+ b64_auth = base64 .b64encode (auth_str .encode ()).decode ()
162+ headers ["Authorization" ] = f"Basic { b64_auth } "
163+
164+ self .session = aiohttp .ClientSession (headers = headers )
100165
101166 # Send initialize request
102167 response = await self ._send_http_request ("initialize" , {
@@ -193,11 +258,23 @@ async def _send_http_request(self, method: str, params: Dict[str, Any]) -> Optio
193258 "params" : params
194259 }
195260
196- async with self .session .post (self .server .url , json = message ) as response :
197- if response .status == 200 :
198- return await response .json ()
199- else :
200- raise RuntimeError (f"HTTP error: { response .status } " )
261+ # Use the expanded URL
262+ url = self ._expand_env_vars (self .server .url )
263+
264+ try :
265+ async with self .session .post (url , json = message , timeout = aiohttp .ClientTimeout (total = 30 )) as response :
266+ if response .status == 200 :
267+ result = await response .json ()
268+ if "error" in result :
269+ raise RuntimeError (f"JSON-RPC error: { result ['error' ]} " )
270+ return result .get ("result" )
271+ else :
272+ text = await response .text ()
273+ raise RuntimeError (f"HTTP error { response .status } : { text } " )
274+ except asyncio .TimeoutError :
275+ raise RuntimeError ("Request timed out" )
276+ except aiohttp .ClientError as e :
277+ raise RuntimeError (f"Connection error: { str (e )} " )
201278
202279 async def _read_messages (self ):
203280 """Read messages from stdio."""
@@ -285,6 +362,7 @@ class MCPServerManager:
285362
286363 def __init__ (self ):
287364 self .connections : Dict [str , MCPConnection ] = {}
365+ self .connection_callbacks : Dict [str , Any ] = {}
288366
289367 async def connect_server (self , server : MCPServer ) -> bool :
290368 """Connect to an MCP server."""
@@ -331,6 +409,40 @@ async def disconnect_all(self):
331409 """Disconnect from all MCP servers."""
332410 for server_name in list (self .connections .keys ()):
333411 await self .disconnect_server (server_name )
412+
413+ def get_server_status (self , server_name : str ) -> Dict [str , Any ]:
414+ """Get the status of a specific server."""
415+ if server_name not in self .connections :
416+ return {
417+ "status" : MCPServerStatus .DISCONNECTED .value ,
418+ "error_message" : None ,
419+ "capabilities" : None ,
420+ "tools" : None ,
421+ "resources" : None ,
422+ "prompts" : None
423+ }
424+
425+ connection = self .connections [server_name ]
426+ return {
427+ "status" : connection .status .value ,
428+ "error_message" : connection .error_message ,
429+ "capabilities" : connection .server .capabilities ,
430+ "tools" : connection .server .tools ,
431+ "resources" : connection .server .resources ,
432+ "prompts" : connection .server .prompts
433+ }
434+
435+ def set_connection_callback (self , server_name : str , callback : Any ):
436+ """Set a callback for connection status changes."""
437+ self .connection_callbacks [server_name ] = callback
438+
439+ async def _notify_status_change (self , server_name : str ):
440+ """Notify about server status change."""
441+ if server_name in self .connection_callbacks :
442+ callback = self .connection_callbacks [server_name ]
443+ if callback :
444+ status = self .get_server_status (server_name )
445+ await callback (server_name , status )
334446
335447
336448# Global instance
0 commit comments