Skip to content

Commit b7fca4d

Browse files
authored
feat(client): add multi-tenant support for client (#518)
1 parent 3534e0f commit b7fca4d

File tree

1 file changed

+158
-71
lines changed

1 file changed

+158
-71
lines changed

veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py

Lines changed: 158 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414

1515
import asyncio
1616
import json
17+
import time
1718
import uuid
19+
import threading
20+
from dataclasses import dataclass, field
1821
from typing import TYPE_CHECKING, Any, Callable, Optional
1922

2023
from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
2124
from fastapi.responses import StreamingResponse
22-
from google.adk.agents.run_config import RunConfig, StreamingMode
25+
from google.adk.agents.run_config import StreamingMode
2326
from google.adk.artifacts import InMemoryArtifactService
2427
from google.adk.cli.adk_web_server import RunAgentRequest
25-
from google.adk.runners import Runner as GoogleRunner
28+
from google.adk.runners import Runner as GoogleRunner, RunConfig
2629
from google.adk.sessions import InMemorySessionService, Session
2730
from google.adk.tools.mcp_tool.mcp_session_manager import (
2831
StreamableHTTPConnectionParams,
@@ -34,6 +37,7 @@
3437
from veadk import Runner
3538
from veadk.utils.logger import get_logger
3639

40+
3741
if TYPE_CHECKING:
3842
from veadk import Agent
3943

@@ -48,47 +52,121 @@ class ExtraRoute(BaseModel):
4852
methods: list[str]
4953

5054

51-
class WebsocketSessionManager:
52-
def __init__(self):
53-
# ws id -> ws instance
54-
self.connections: dict[str, WebSocket] = {}
55+
@dataclass
56+
class ClientResource:
57+
websocket: WebSocket
58+
agent: "Agent"
59+
session_service: InMemorySessionService
60+
artifact_service: InMemoryArtifactService
61+
pending_requests: dict[str, asyncio.Future] = field(default_factory=dict)
62+
last_active_time: float = field(default_factory=time.time)
63+
64+
def update_activity(self):
65+
self.last_active_time = time.time()
5566

56-
# ws id -> msg id -> ret
57-
self.pendings: dict[str, dict[str, asyncio.Future]] = {}
5867

59-
async def call_mcp_http(self, ws_id: str, request: dict):
68+
class ResourceManager:
69+
def __init__(self, timeout_seconds: int = 3600):
70+
self._lock: threading.Lock = threading.Lock()
71+
self.resources: dict[str, ClientResource] = {}
72+
self.timeout_seconds = timeout_seconds
73+
self.cleanup_task: Optional[asyncio.Task] = None
74+
75+
def register(
76+
self,
77+
client_id: str,
78+
websocket: WebSocket,
79+
agent: "Agent",
80+
session_service: InMemorySessionService,
81+
artifact_service: InMemoryArtifactService,
82+
):
83+
with self._lock:
84+
self.resources[client_id] = ClientResource(
85+
websocket=websocket,
86+
agent=agent,
87+
session_service=session_service,
88+
artifact_service=artifact_service,
89+
)
90+
logger.info(f"client {client_id} registered")
91+
92+
def get(self, client_id: str) -> Optional[ClientResource]:
93+
with self._lock:
94+
logger.info(f"get {client_id}")
95+
resource = self.resources.get(client_id)
96+
if resource:
97+
resource.update_activity()
98+
return resource
99+
100+
async def remove(self, client_id: str):
101+
if client_id in self.resources:
102+
resource = self.resources.pop(client_id)
103+
try:
104+
await resource.websocket.close()
105+
for fut in resource.pending_requests.values():
106+
if not fut.done():
107+
fut.cancel()
108+
except Exception as e:
109+
logger.warning(
110+
f"client {client_id} resource websocket close error: {e}"
111+
)
112+
pass
113+
114+
async def start_cleanup_loop(self):
115+
logger.info("ResourceManager: active cleanup loop")
116+
while True:
117+
await asyncio.sleep(60) # Check every minute
118+
logger.debug("cleanup loop running...")
119+
now = time.time()
120+
to_remove = []
121+
for client_id, resource in self.resources.items():
122+
logger.debug(
123+
f"check {client_id}, last_active_time={resource.last_active_time}, timeout={self.timeout_seconds}"
124+
)
125+
if now - resource.last_active_time > self.timeout_seconds:
126+
to_remove.append(client_id)
127+
128+
for client_id in to_remove:
129+
logger.info(f"Removing inactive client {client_id}")
130+
await self.remove(client_id)
131+
132+
def start(self):
133+
self.cleanup_task = asyncio.create_task(self.start_cleanup_loop())
134+
135+
def stop(self):
136+
if self.cleanup_task:
137+
self.cleanup_task.cancel()
138+
139+
async def call_mcp_http(self, client_id: str, request: dict):
60140
"""Forward MCP request to client."""
61-
try:
62-
ws = self.connections[ws_id]
63-
except KeyError:
64-
logger.error(f"Websocket {ws_id} not found")
141+
resource = self.get(client_id)
142+
if not resource:
143+
logger.error(f"Client {client_id} not found")
65144
return b""
66145

67-
msg = {}
68-
69-
msg["id"] = str(uuid.uuid4())
70-
msg["type"] = "http_request"
71-
msg["payload"] = request
146+
ws = resource.websocket
147+
msg = {"id": str(uuid.uuid4()), "type": "http_request", "payload": request}
72148

73149
fut = asyncio.get_event_loop().create_future()
74150

75-
if ws_id not in self.pendings:
76-
self.pendings[ws_id] = {}
77-
78-
self.pendings[ws_id][msg["id"]] = fut
151+
resource.pending_requests[msg["id"]] = fut
79152

80153
await ws.send_text(json.dumps(msg))
81154
return await fut
82155

83-
async def handle_ws_message(self, ws_id: str, raw: str):
156+
async def handle_ws_message(self, client_id: str, raw: str):
157+
resource = self.get(client_id)
158+
if not resource:
159+
return
160+
84161
msg = json.loads(raw)
85162
if msg.get("type") != "http_response":
86163
return
87164

88165
req_id = msg["id"]
89-
fut = self.pendings[ws_id].pop(req_id, None)
166+
fut = resource.pending_requests.pop(req_id, None)
90167
if fut:
91168
fut.set_result(msg)
169+
# todo : 异常ID处理
92170

93171

94172
class ServerWithReverseMCP:
@@ -102,27 +180,25 @@ def __init__(
102180
extra_routes: list[ExtraRoute] | None = None,
103181
):
104182
self.agent = agent
105-
106183
self.host = host
107184
self.port = port
108-
109185
self.extra_routes = extra_routes
110186

111-
self.app = FastAPI(
112-
openapi_url=None,
113-
docs_url=None,
114-
redoc_url=None,
115-
swagger_ui_oauth2_redirect_url=None,
116-
)
187+
self.app = FastAPI()
117188

118189
self.artifact_service = InMemoryArtifactService()
190+
self.resource_manager = ResourceManager()
119191

120192
# build routes for self.app
121193
self.build()
122194

123-
self.ws_session_mgr = WebsocketSessionManager()
124-
self.ws_agent_mgr: dict[str, "Agent"] = {}
125-
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}
195+
@self.app.on_event("startup")
196+
async def startup_event():
197+
self.resource_manager.start()
198+
199+
@self.app.on_event("shutdown")
200+
async def shutdown_event():
201+
self.resource_manager.stop()
126202

127203
def build(self):
128204
logger.info("Build routes for server with reverse mcp")
@@ -149,9 +225,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
149225
session_id = payload.session_id
150226
prompt = payload.prompt
151227

152-
agent = self.ws_agent_mgr[payload.websocket_id]
228+
resource = self.resource_manager.get(payload.websocket_id)
229+
if not resource:
230+
raise HTTPException(
231+
status_code=404, detail=f"Client {payload.websocket_id} not found"
232+
)
233+
agent = resource.agent
153234

154-
runner = Runner(app_name=payload.app_name, agent=agent)
235+
runner = Runner(
236+
app_name=payload.app_name,
237+
agent=agent,
238+
session_service=resource.session_service,
239+
)
155240
response = await runner.run(
156241
messages=[prompt],
157242
user_id=user_id,
@@ -160,6 +245,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
160245

161246
return InvokeResponse(response=response)
162247

248+
@self.app.delete("/management/clients/{client_id}")
249+
async def delete_client(client_id: str):
250+
"""Manually remove a client resource."""
251+
await self.resource_manager.remove(client_id)
252+
return {"status": "success", "client_id": client_id}
253+
163254
# build websocket endpoint
164255
@self.app.websocket("/ws")
165256
async def ws_endpoint(ws: WebSocket):
@@ -179,15 +270,10 @@ async def ws_endpoint(ws: WebSocket):
179270
filters = [t.strip() for t in filters_str.split(",") if t.strip()]
180271

181272
logger.info(f"Register websocket {client_id} to session manager.")
182-
self.ws_session_mgr.connections[client_id] = ws
183273

184274
logger.info(f"Fork agent for websocket {client_id}")
185275
agent = self.agent.clone()
186276

187-
logger.info(
188-
f"clone agent \n model_name={agent.model_name}\n instruction={agent.instruction}\n"
189-
)
190-
191277
# Mount MCPToolset when creating agent
192278
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
193279
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
@@ -201,10 +287,18 @@ async def ws_endpoint(ws: WebSocket):
201287
tool_filter=filters,
202288
)
203289
)
204-
self.ws_agent_mgr[client_id] = agent
205290

206291
logger.info(f"Create session service for websocket {client_id}")
207-
self.ws_session_service_mgr[client_id] = InMemorySessionService()
292+
session_service = InMemorySessionService()
293+
artifact_service = InMemoryArtifactService()
294+
295+
self.resource_manager.register(
296+
client_id=client_id,
297+
websocket=ws,
298+
agent=agent,
299+
session_service=session_service,
300+
artifact_service=artifact_service,
301+
)
208302

209303
await ws.accept()
210304
logger.info(f"Websocket {client_id} connected")
@@ -213,7 +307,7 @@ async def ws_endpoint(ws: WebSocket):
213307
while True:
214308
raw = await ws.receive_text()
215309
logger.debug(f"ws.receive_text() = {raw}")
216-
await self.ws_session_mgr.handle_ws_message(client_id, raw)
310+
await self.resource_manager.handle_ws_message(client_id, raw)
217311
except Exception as e:
218312
logger.warning(f"client {client_id} web socket connection closed: {e}")
219313

@@ -227,12 +321,12 @@ class RunAgentRequestWithWsId(RunAgentRequest):
227321

228322
def _get_session_service(websocket_id: str) -> InMemorySessionService:
229323
"""Get session service for the websocket client."""
230-
if websocket_id not in self.ws_session_service_mgr:
324+
resource = self.resource_manager.get(websocket_id)
325+
if not resource:
231326
raise HTTPException(
232-
status_code=404,
233-
detail=f"WebSocket client {websocket_id} not found",
327+
status_code=404, detail=f"WebSocket client {websocket_id} not found"
234328
)
235-
return self.ws_session_service_mgr[websocket_id]
329+
return resource.session_service
236330

237331
@self.app.post(
238332
"/apps/{app_name}/users/{user_id}/sessions",
@@ -291,11 +385,18 @@ async def create_session_with_id(
291385
return session
292386

293387
@self.app.post("/run_sse")
294-
async def run_agent_sse(
295-
req: RunAgentRequestWithWsId,
296-
) -> StreamingResponse:
388+
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
297389
"""Run agent with SSE streaming."""
298-
session_service = _get_session_service(req.websocket_id)
390+
resource = self.resource_manager.get(req.websocket_id)
391+
if not resource:
392+
raise HTTPException(
393+
status_code=404,
394+
detail=f"WebSocket client {req.websocket_id} not found",
395+
)
396+
397+
session_service = resource.session_service
398+
agent = resource.agent
399+
logger.debug(f"Using agent from websocket {req.websocket_id}")
299400

300401
# Get session
301402
session = await session_service.get_session(
@@ -306,16 +407,6 @@ async def run_agent_sse(
306407
if not session:
307408
raise HTTPException(status_code=404, detail="Session not found")
308409

309-
# Get agent for this websocket
310-
if req.websocket_id in self.ws_agent_mgr:
311-
agent = self.ws_agent_mgr[req.websocket_id]
312-
logger.debug(f"Using agent from websocket {req.websocket_id}")
313-
else:
314-
raise HTTPException(
315-
status_code=404,
316-
detail=f"WebSocket client {req.websocket_id} not found",
317-
)
318-
319410
# Create runner
320411
runner = GoogleRunner(
321412
agent=agent,
@@ -354,10 +445,7 @@ async def event_generator():
354445
content_event.actions.artifact_delta = {}
355446
artifact_event = event.model_copy(deep=True)
356447
artifact_event.content = None
357-
events_to_stream = [
358-
content_event,
359-
artifact_event,
360-
]
448+
events_to_stream = [content_event, artifact_event]
361449

362450
for event_to_stream in events_to_stream:
363451
sse_event = event_to_stream.model_dump_json(
@@ -367,7 +455,7 @@ async def event_generator():
367455
yield f"data: {sse_event}\n\n"
368456
except Exception as e:
369457
logger.exception(f"Error in event_generator: {e}")
370-
yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
458+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
371459

372460
return StreamingResponse(
373461
event_generator(),
@@ -391,8 +479,7 @@ async def mcp_proxy(path: str, request: Request):
391479
if not client_id:
392480
return Response("client id not found", status_code=400)
393481

394-
ws = self.ws_session_mgr.connections.get(client_id)
395-
if not ws:
482+
if not self.resource_manager.get(client_id):
396483
return Response("websocket `client_id` not connected", status_code=503)
397484

398485
body = await request.body()
@@ -409,7 +496,7 @@ async def mcp_proxy(path: str, request: Request):
409496

410497
logger.debug(f"[Reverse mcp proxy] Request from agent: {payload}")
411498

412-
resp = await self.ws_session_mgr.call_mcp_http(client_id, payload)
499+
resp = await self.resource_manager.call_mcp_http(client_id, payload)
413500

414501
logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")
415502

0 commit comments

Comments
 (0)