Skip to content

Commit 1cf19e6

Browse files
Userclaude
andcommitted
feat: enhance MCP server support with authentication and environment variables
- Add authentication fields (auth_type, auth_config) to MCP server model - Support bearer token, API key, and basic authentication for remote servers - Add environment variable expansion for server configurations - Enhance error handling and connection resilience - Improve WebSocket status reporting with error messages - Add timeout and better error messages for HTTP requests - Better process management for stdio connections 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e6d0c3d commit 1cf19e6

File tree

4 files changed

+143
-18
lines changed

4 files changed

+143
-18
lines changed

backend/app/alembic/versions/2025_05_27_0402_add_mcp_server_table.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,20 @@
1414

1515
# revision identifiers, used by Alembic.
1616
revision: str = '2025_05_27_0402'
17-
down_revision: Union[str, None] = '2025_05_26_1343'
17+
down_revision: Union[str, None] = '1a31ce608336'
1818
branch_labels: Union[str, Sequence[str], None] = None
1919
depends_on: Union[str, Sequence[str], None] = None
2020

2121

2222
def upgrade() -> None:
23-
# Create enum type for transport
24-
op.execute("CREATE TYPE mcptransporttype AS ENUM ('stdio', 'http_sse')")
23+
# Create enum type for transport if it doesn't exist
24+
op.execute("""
25+
DO $$ BEGIN
26+
CREATE TYPE mcptransporttype AS ENUM ('stdio', 'http_sse');
27+
EXCEPTION
28+
WHEN duplicate_object THEN null;
29+
END $$;
30+
""")
2531

2632
# Create mcp_servers table
2733
op.create_table('mcp_servers',
@@ -37,6 +43,8 @@ def upgrade() -> None:
3743
sa.Column('headers', postgresql.JSON(astext_type=sa.Text()), nullable=True),
3844
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='true'),
3945
sa.Column('is_remote', sa.Boolean(), nullable=False, server_default='false'),
46+
sa.Column('auth_type', sa.String(), nullable=True),
47+
sa.Column('auth_config', postgresql.JSON(astext_type=sa.Text()), nullable=True),
4048
sa.Column('capabilities', postgresql.JSON(astext_type=sa.Text()), nullable=True),
4149
sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True),
4250
sa.Column('resources', postgresql.JSON(astext_type=sa.Text()), nullable=True),

backend/app/api/routes/mcp_websocket.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ async def handle_websocket_message(
204204

205205
if server.name in mcp_manager.connections:
206206
connection = mcp_manager.connections[server.name]
207-
status["status"] = connection.server.status.value
207+
status["status"] = connection.status.value
208+
status["error_message"] = connection.error_message
208209
status["capabilities"] = connection.server.capabilities
209210
status["tools"] = connection.server.tools
210211
status["resources"] = connection.server.resources

backend/app/models/mcp_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class MCPServerBase(SQLModel):
3434
headers: Optional[dict[str, str]] = Field(default=None, sa_column=Column(JSON))
3535
is_enabled: bool = Field(default=True)
3636
is_remote: bool = Field(default=False)
37+
auth_type: Optional[str] = Field(default=None, description="Authentication type: 'api_key', 'bearer', 'basic'")
38+
auth_config: Optional[dict[str, str]] = Field(default=None, sa_column=Column(JSON), description="Authentication configuration")
3739

3840

3941
class MCPServer(BaseDBModel, MCPServerBase, table=True):
@@ -63,6 +65,8 @@ class MCPServerUpdate(SQLModel):
6365
headers: Optional[dict[str, str]] = None
6466
is_enabled: Optional[bool] = None
6567
is_remote: Optional[bool] = None
68+
auth_type: Optional[str] = None
69+
auth_config: Optional[dict[str, str]] = None
6670

6771

6872
class MCPServerPublic(MCPServerBase):

backend/app/services/mcp_manager.py

Lines changed: 126 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import asyncio
33
import json
44
import logging
5+
import os
6+
import re
57
from typing import Dict, Optional, Any
68
from datetime import datetime
79
import subprocess
@@ -28,6 +30,27 @@ def __init__(self, server: MCPServer):
2830
self.status: MCPServerStatus = MCPServerStatus.DISCONNECTED
2931
self.error_message: Optional[str] = None
3032

33+
def _expand_env_vars(self, value: str) -> str:
34+
"""Expand environment variables in a string."""
35+
# Replace ${VAR} and $VAR patterns with environment variable values
36+
def replace_env_var(match):
37+
var_name = match.group(1) or match.group(2)
38+
return os.environ.get(var_name, match.group(0))
39+
40+
# Match ${VAR} or $VAR patterns
41+
pattern = r'\$\{([^}]+)\}|\$([A-Za-z_][A-Za-z0-9_]*)'
42+
return re.sub(pattern, replace_env_var, value)
43+
44+
def _expand_config_env_vars(self, config: Any) -> Any:
45+
"""Recursively expand environment variables in configuration."""
46+
if isinstance(config, str):
47+
return self._expand_env_vars(config)
48+
elif isinstance(config, dict):
49+
return {k: self._expand_config_env_vars(v) for k, v in config.items()}
50+
elif isinstance(config, list):
51+
return [self._expand_config_env_vars(item) for item in config]
52+
return config
53+
3154
async def connect(self) -> bool:
3255
"""Connect to the MCP server."""
3356
try:
@@ -48,21 +71,36 @@ async def _connect_stdio(self) -> bool:
4871
if not self.server.command:
4972
raise ValueError("Command is required for stdio transport")
5073

51-
args = [self.server.command]
74+
# Expand environment variables in command and args
75+
command = self._expand_env_vars(self.server.command)
76+
args = [command]
5277
if self.server.args:
53-
args.extend(self.server.args)
78+
expanded_args = [self._expand_env_vars(arg) for arg in self.server.args]
79+
args.extend(expanded_args)
5480

5581
# Start the process
56-
self.process = await asyncio.create_subprocess_exec(
57-
*args,
58-
stdin=asyncio.subprocess.PIPE,
59-
stdout=asyncio.subprocess.PIPE,
60-
stderr=asyncio.subprocess.PIPE
61-
)
82+
try:
83+
self.process = await asyncio.create_subprocess_exec(
84+
*args,
85+
stdin=asyncio.subprocess.PIPE,
86+
stdout=asyncio.subprocess.PIPE,
87+
stderr=asyncio.subprocess.PIPE,
88+
env={**os.environ} # Pass current environment variables
89+
)
90+
except FileNotFoundError:
91+
raise RuntimeError(f"Command not found: {command}")
92+
except Exception as e:
93+
raise RuntimeError(f"Failed to start process: {str(e)}")
6294

6395
self.reader = self.process.stdout
6496
self.writer = self.process.stdin
6597

98+
# Check if process started successfully
99+
await asyncio.sleep(0.1) # Give process time to start
100+
if self.process.returncode is not None:
101+
stderr = await self.process.stderr.read()
102+
raise RuntimeError(f"Process exited immediately with code {self.process.returncode}: {stderr.decode()}")
103+
66104
# Start reading messages
67105
self._read_task = asyncio.create_task(self._read_messages())
68106

@@ -96,7 +134,34 @@ async def _connect_http_sse(self) -> bool:
96134
if not self.server.url:
97135
raise ValueError("URL is required for HTTP+SSE transport")
98136

99-
self.session = aiohttp.ClientSession(headers=self.server.headers or {})
137+
# Expand environment variables in URL
138+
url = self._expand_env_vars(self.server.url)
139+
140+
# Build headers with authentication
141+
headers = self._expand_config_env_vars(dict(self.server.headers or {}))
142+
if self.server.auth_type and self.server.auth_config:
143+
# Expand environment variables in auth config
144+
auth_config = self._expand_config_env_vars(self.server.auth_config)
145+
146+
if self.server.auth_type == "bearer":
147+
token = auth_config.get("token")
148+
if token:
149+
headers["Authorization"] = f"Bearer {token}"
150+
elif self.server.auth_type == "api_key":
151+
key_name = auth_config.get("key_name", "X-API-Key")
152+
api_key = auth_config.get("api_key")
153+
if api_key:
154+
headers[key_name] = api_key
155+
elif self.server.auth_type == "basic":
156+
username = auth_config.get("username")
157+
password = auth_config.get("password")
158+
if username and password:
159+
import base64
160+
auth_str = f"{username}:{password}"
161+
b64_auth = base64.b64encode(auth_str.encode()).decode()
162+
headers["Authorization"] = f"Basic {b64_auth}"
163+
164+
self.session = aiohttp.ClientSession(headers=headers)
100165

101166
# Send initialize request
102167
response = await self._send_http_request("initialize", {
@@ -193,11 +258,23 @@ async def _send_http_request(self, method: str, params: Dict[str, Any]) -> Optio
193258
"params": params
194259
}
195260

196-
async with self.session.post(self.server.url, json=message) as response:
197-
if response.status == 200:
198-
return await response.json()
199-
else:
200-
raise RuntimeError(f"HTTP error: {response.status}")
261+
# Use the expanded URL
262+
url = self._expand_env_vars(self.server.url)
263+
264+
try:
265+
async with self.session.post(url, json=message, timeout=aiohttp.ClientTimeout(total=30)) as response:
266+
if response.status == 200:
267+
result = await response.json()
268+
if "error" in result:
269+
raise RuntimeError(f"JSON-RPC error: {result['error']}")
270+
return result.get("result")
271+
else:
272+
text = await response.text()
273+
raise RuntimeError(f"HTTP error {response.status}: {text}")
274+
except asyncio.TimeoutError:
275+
raise RuntimeError("Request timed out")
276+
except aiohttp.ClientError as e:
277+
raise RuntimeError(f"Connection error: {str(e)}")
201278

202279
async def _read_messages(self):
203280
"""Read messages from stdio."""
@@ -285,6 +362,7 @@ class MCPServerManager:
285362

286363
def __init__(self):
287364
self.connections: Dict[str, MCPConnection] = {}
365+
self.connection_callbacks: Dict[str, Any] = {}
288366

289367
async def connect_server(self, server: MCPServer) -> bool:
290368
"""Connect to an MCP server."""
@@ -331,6 +409,40 @@ async def disconnect_all(self):
331409
"""Disconnect from all MCP servers."""
332410
for server_name in list(self.connections.keys()):
333411
await self.disconnect_server(server_name)
412+
413+
def get_server_status(self, server_name: str) -> Dict[str, Any]:
414+
"""Get the status of a specific server."""
415+
if server_name not in self.connections:
416+
return {
417+
"status": MCPServerStatus.DISCONNECTED.value,
418+
"error_message": None,
419+
"capabilities": None,
420+
"tools": None,
421+
"resources": None,
422+
"prompts": None
423+
}
424+
425+
connection = self.connections[server_name]
426+
return {
427+
"status": connection.status.value,
428+
"error_message": connection.error_message,
429+
"capabilities": connection.server.capabilities,
430+
"tools": connection.server.tools,
431+
"resources": connection.server.resources,
432+
"prompts": connection.server.prompts
433+
}
434+
435+
def set_connection_callback(self, server_name: str, callback: Any):
436+
"""Set a callback for connection status changes."""
437+
self.connection_callbacks[server_name] = callback
438+
439+
async def _notify_status_change(self, server_name: str):
440+
"""Notify about server status change."""
441+
if server_name in self.connection_callbacks:
442+
callback = self.connection_callbacks[server_name]
443+
if callback:
444+
status = self.get_server_status(server_name)
445+
await callback(server_name, status)
334446

335447

336448
# Global instance

0 commit comments

Comments
 (0)