1414
1515import asyncio
1616import json
17+ import time
1718import uuid
19+ import threading
20+ from dataclasses import dataclass , field
1821from typing import TYPE_CHECKING , Any , Callable , Optional
1922
2023from fastapi import FastAPI , HTTPException , Request , Response , WebSocket
2124from fastapi .responses import StreamingResponse
22- from google .adk .agents .run_config import RunConfig , StreamingMode
25+ from google .adk .agents .run_config import StreamingMode
2326from google .adk .artifacts import InMemoryArtifactService
2427from google .adk .cli .adk_web_server import RunAgentRequest
25- from google .adk .runners import Runner as GoogleRunner
28+ from google .adk .runners import Runner as GoogleRunner , RunConfig
2629from google .adk .sessions import InMemorySessionService , Session
2730from google .adk .tools .mcp_tool .mcp_session_manager import (
2831 StreamableHTTPConnectionParams ,
3437from veadk import Runner
3538from veadk .utils .logger import get_logger
3639
40+
3741if TYPE_CHECKING :
3842 from veadk import Agent
3943
@@ -48,47 +52,121 @@ class ExtraRoute(BaseModel):
4852 methods : list [str ]
4953
5054
51- class WebsocketSessionManager :
52- def __init__ (self ):
53- # ws id -> ws instance
54- self .connections : dict [str , WebSocket ] = {}
55+ @dataclass
56+ class ClientResource :
57+ websocket : WebSocket
58+ agent : "Agent"
59+ session_service : InMemorySessionService
60+ artifact_service : InMemoryArtifactService
61+ pending_requests : dict [str , asyncio .Future ] = field (default_factory = dict )
62+ last_active_time : float = field (default_factory = time .time )
63+
64+ def update_activity (self ):
65+ self .last_active_time = time .time ()
5566
56- # ws id -> msg id -> ret
57- self .pendings : dict [str , dict [str , asyncio .Future ]] = {}
5867
59- async def call_mcp_http (self , ws_id : str , request : dict ):
68+ class ResourceManager :
69+ def __init__ (self , timeout_seconds : int = 3600 ):
70+ self ._lock : threading .Lock = threading .Lock ()
71+ self .resources : dict [str , ClientResource ] = {}
72+ self .timeout_seconds = timeout_seconds
73+ self .cleanup_task : Optional [asyncio .Task ] = None
74+
75+ def register (
76+ self ,
77+ client_id : str ,
78+ websocket : WebSocket ,
79+ agent : "Agent" ,
80+ session_service : InMemorySessionService ,
81+ artifact_service : InMemoryArtifactService ,
82+ ):
83+ with self ._lock :
84+ self .resources [client_id ] = ClientResource (
85+ websocket = websocket ,
86+ agent = agent ,
87+ session_service = session_service ,
88+ artifact_service = artifact_service ,
89+ )
90+ logger .info (f"client { client_id } registered" )
91+
92+ def get (self , client_id : str ) -> Optional [ClientResource ]:
93+ with self ._lock :
94+ logger .info (f"get { client_id } " )
95+ resource = self .resources .get (client_id )
96+ if resource :
97+ resource .update_activity ()
98+ return resource
99+
100+ async def remove (self , client_id : str ):
101+ if client_id in self .resources :
102+ resource = self .resources .pop (client_id )
103+ try :
104+ await resource .websocket .close ()
105+ for fut in resource .pending_requests .values ():
106+ if not fut .done ():
107+ fut .cancel ()
108+ except Exception as e :
109+ logger .warning (
110+ f"client { client_id } resource websocket close error: { e } "
111+ )
112+ pass
113+
114+ async def start_cleanup_loop (self ):
115+ logger .info ("ResourceManager: active cleanup loop" )
116+ while True :
117+ await asyncio .sleep (60 ) # Check every minute
118+ logger .debug ("cleanup loop running..." )
119+ now = time .time ()
120+ to_remove = []
121+ for client_id , resource in self .resources .items ():
122+ logger .debug (
123+ f"check { client_id } , last_active_time={ resource .last_active_time } , timeout={ self .timeout_seconds } "
124+ )
125+ if now - resource .last_active_time > self .timeout_seconds :
126+ to_remove .append (client_id )
127+
128+ for client_id in to_remove :
129+ logger .info (f"Removing inactive client { client_id } " )
130+ await self .remove (client_id )
131+
132+ def start (self ):
133+ self .cleanup_task = asyncio .create_task (self .start_cleanup_loop ())
134+
135+ def stop (self ):
136+ if self .cleanup_task :
137+ self .cleanup_task .cancel ()
138+
139+ async def call_mcp_http (self , client_id : str , request : dict ):
60140 """Forward MCP request to client."""
61- try :
62- ws = self .connections [ws_id ]
63- except KeyError :
64- logger .error (f"Websocket { ws_id } not found" )
141+ resource = self .get (client_id )
142+ if not resource :
143+ logger .error (f"Client { client_id } not found" )
65144 return b""
66145
67- msg = {}
68-
69- msg ["id" ] = str (uuid .uuid4 ())
70- msg ["type" ] = "http_request"
71- msg ["payload" ] = request
146+ ws = resource .websocket
147+ msg = {"id" : str (uuid .uuid4 ()), "type" : "http_request" , "payload" : request }
72148
73149 fut = asyncio .get_event_loop ().create_future ()
74150
75- if ws_id not in self .pendings :
76- self .pendings [ws_id ] = {}
77-
78- self .pendings [ws_id ][msg ["id" ]] = fut
151+ resource .pending_requests [msg ["id" ]] = fut
79152
80153 await ws .send_text (json .dumps (msg ))
81154 return await fut
82155
83- async def handle_ws_message (self , ws_id : str , raw : str ):
156+ async def handle_ws_message (self , client_id : str , raw : str ):
157+ resource = self .get (client_id )
158+ if not resource :
159+ return
160+
84161 msg = json .loads (raw )
85162 if msg .get ("type" ) != "http_response" :
86163 return
87164
88165 req_id = msg ["id" ]
89- fut = self . pendings [ ws_id ] .pop (req_id , None )
166+ fut = resource . pending_requests .pop (req_id , None )
90167 if fut :
91168 fut .set_result (msg )
169+ # todo : 异常ID处理
92170
93171
94172class ServerWithReverseMCP :
@@ -102,27 +180,25 @@ def __init__(
102180 extra_routes : list [ExtraRoute ] | None = None ,
103181 ):
104182 self .agent = agent
105-
106183 self .host = host
107184 self .port = port
108-
109185 self .extra_routes = extra_routes
110186
111- self .app = FastAPI (
112- openapi_url = None ,
113- docs_url = None ,
114- redoc_url = None ,
115- swagger_ui_oauth2_redirect_url = None ,
116- )
187+ self .app = FastAPI ()
117188
118189 self .artifact_service = InMemoryArtifactService ()
190+ self .resource_manager = ResourceManager ()
119191
120192 # build routes for self.app
121193 self .build ()
122194
123- self .ws_session_mgr = WebsocketSessionManager ()
124- self .ws_agent_mgr : dict [str , "Agent" ] = {}
125- self .ws_session_service_mgr : dict [str , "InMemorySessionService" ] = {}
195+ @self .app .on_event ("startup" )
196+ async def startup_event ():
197+ self .resource_manager .start ()
198+
199+ @self .app .on_event ("shutdown" )
200+ async def shutdown_event ():
201+ self .resource_manager .stop ()
126202
127203 def build (self ):
128204 logger .info ("Build routes for server with reverse mcp" )
@@ -149,9 +225,18 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
149225 session_id = payload .session_id
150226 prompt = payload .prompt
151227
152- agent = self .ws_agent_mgr [payload .websocket_id ]
228+ resource = self .resource_manager .get (payload .websocket_id )
229+ if not resource :
230+ raise HTTPException (
231+ status_code = 404 , detail = f"Client { payload .websocket_id } not found"
232+ )
233+ agent = resource .agent
153234
154- runner = Runner (app_name = payload .app_name , agent = agent )
235+ runner = Runner (
236+ app_name = payload .app_name ,
237+ agent = agent ,
238+ session_service = resource .session_service ,
239+ )
155240 response = await runner .run (
156241 messages = [prompt ],
157242 user_id = user_id ,
@@ -160,6 +245,12 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
160245
161246 return InvokeResponse (response = response )
162247
248+ @self .app .delete ("/management/clients/{client_id}" )
249+ async def delete_client (client_id : str ):
250+ """Manually remove a client resource."""
251+ await self .resource_manager .remove (client_id )
252+ return {"status" : "success" , "client_id" : client_id }
253+
163254 # build websocket endpoint
164255 @self .app .websocket ("/ws" )
165256 async def ws_endpoint (ws : WebSocket ):
@@ -179,15 +270,10 @@ async def ws_endpoint(ws: WebSocket):
179270 filters = [t .strip () for t in filters_str .split ("," ) if t .strip ()]
180271
181272 logger .info (f"Register websocket { client_id } to session manager." )
182- self .ws_session_mgr .connections [client_id ] = ws
183273
184274 logger .info (f"Fork agent for websocket { client_id } " )
185275 agent = self .agent .clone ()
186276
187- logger .info (
188- f"clone agent \n model_name={ agent .model_name } \n instruction={ agent .instruction } \n "
189- )
190-
191277 # Mount MCPToolset when creating agent
192278 mcp_toolset_url = f"http://127.0.0.1:{ self .port } /mcp"
193279 mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY : client_id }
@@ -201,10 +287,18 @@ async def ws_endpoint(ws: WebSocket):
201287 tool_filter = filters ,
202288 )
203289 )
204- self .ws_agent_mgr [client_id ] = agent
205290
206291 logger .info (f"Create session service for websocket { client_id } " )
207- self .ws_session_service_mgr [client_id ] = InMemorySessionService ()
292+ session_service = InMemorySessionService ()
293+ artifact_service = InMemoryArtifactService ()
294+
295+ self .resource_manager .register (
296+ client_id = client_id ,
297+ websocket = ws ,
298+ agent = agent ,
299+ session_service = session_service ,
300+ artifact_service = artifact_service ,
301+ )
208302
209303 await ws .accept ()
210304 logger .info (f"Websocket { client_id } connected" )
@@ -213,7 +307,7 @@ async def ws_endpoint(ws: WebSocket):
213307 while True :
214308 raw = await ws .receive_text ()
215309 logger .debug (f"ws.receive_text() = { raw } " )
216- await self .ws_session_mgr .handle_ws_message (client_id , raw )
310+ await self .resource_manager .handle_ws_message (client_id , raw )
217311 except Exception as e :
218312 logger .warning (f"client { client_id } web socket connection closed: { e } " )
219313
@@ -227,12 +321,12 @@ class RunAgentRequestWithWsId(RunAgentRequest):
227321
228322 def _get_session_service (websocket_id : str ) -> InMemorySessionService :
229323 """Get session service for the websocket client."""
230- if websocket_id not in self .ws_session_service_mgr :
324+ resource = self .resource_manager .get (websocket_id )
325+ if not resource :
231326 raise HTTPException (
232- status_code = 404 ,
233- detail = f"WebSocket client { websocket_id } not found" ,
327+ status_code = 404 , detail = f"WebSocket client { websocket_id } not found"
234328 )
235- return self . ws_session_service_mgr [ websocket_id ]
329+ return resource . session_service
236330
237331 @self .app .post (
238332 "/apps/{app_name}/users/{user_id}/sessions" ,
@@ -291,11 +385,18 @@ async def create_session_with_id(
291385 return session
292386
293387 @self .app .post ("/run_sse" )
294- async def run_agent_sse (
295- req : RunAgentRequestWithWsId ,
296- ) -> StreamingResponse :
388+ async def run_agent_sse (req : RunAgentRequestWithWsId ) -> StreamingResponse :
297389 """Run agent with SSE streaming."""
298- session_service = _get_session_service (req .websocket_id )
390+ resource = self .resource_manager .get (req .websocket_id )
391+ if not resource :
392+ raise HTTPException (
393+ status_code = 404 ,
394+ detail = f"WebSocket client { req .websocket_id } not found" ,
395+ )
396+
397+ session_service = resource .session_service
398+ agent = resource .agent
399+ logger .debug (f"Using agent from websocket { req .websocket_id } " )
299400
300401 # Get session
301402 session = await session_service .get_session (
@@ -306,16 +407,6 @@ async def run_agent_sse(
306407 if not session :
307408 raise HTTPException (status_code = 404 , detail = "Session not found" )
308409
309- # Get agent for this websocket
310- if req .websocket_id in self .ws_agent_mgr :
311- agent = self .ws_agent_mgr [req .websocket_id ]
312- logger .debug (f"Using agent from websocket { req .websocket_id } " )
313- else :
314- raise HTTPException (
315- status_code = 404 ,
316- detail = f"WebSocket client { req .websocket_id } not found" ,
317- )
318-
319410 # Create runner
320411 runner = GoogleRunner (
321412 agent = agent ,
@@ -354,10 +445,7 @@ async def event_generator():
354445 content_event .actions .artifact_delta = {}
355446 artifact_event = event .model_copy (deep = True )
356447 artifact_event .content = None
357- events_to_stream = [
358- content_event ,
359- artifact_event ,
360- ]
448+ events_to_stream = [content_event , artifact_event ]
361449
362450 for event_to_stream in events_to_stream :
363451 sse_event = event_to_stream .model_dump_json (
@@ -367,7 +455,7 @@ async def event_generator():
367455 yield f"data: { sse_event } \n \n "
368456 except Exception as e :
369457 logger .exception (f"Error in event_generator: { e } " )
370- yield f"data: { json .dumps ({'error' : 'Internal server error' })} \n \n "
458+ yield f"data: { json .dumps ({'error' : str ( e ) })} \n \n "
371459
372460 return StreamingResponse (
373461 event_generator (),
@@ -391,8 +479,7 @@ async def mcp_proxy(path: str, request: Request):
391479 if not client_id :
392480 return Response ("client id not found" , status_code = 400 )
393481
394- ws = self .ws_session_mgr .connections .get (client_id )
395- if not ws :
482+ if not self .resource_manager .get (client_id ):
396483 return Response ("websocket `client_id` not connected" , status_code = 503 )
397484
398485 body = await request .body ()
@@ -409,7 +496,7 @@ async def mcp_proxy(path: str, request: Request):
409496
410497 logger .debug (f"[Reverse mcp proxy] Request from agent: { payload } " )
411498
412- resp = await self .ws_session_mgr .call_mcp_http (client_id , payload )
499+ resp = await self .resource_manager .call_mcp_http (client_id , payload )
413500
414501 logger .debug (f"[Reverse mcp proxy] Response from local: { resp } " )
415502
0 commit comments