Skip to content

Commit e328af3

Browse files
committed
feat: add watchdog for runtime session handling
1 parent 775726b commit e328af3

9 files changed

Lines changed: 1912 additions & 1403 deletions

File tree

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ cython_debug/
165165
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166166
# and can be added to the global gitignore or merged into this file. For a more nuclear
167167
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
168-
#.idea/
168+
.idea/
169+
170+
.vscode/
171+
172+
.claude/
169173

170174
# Ruff stuff:
171175
.ruff_cache/
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"version": "2.0",
3+
"resources": []
4+
}

samples/mcp-math-server/mcp.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"servers": {
3-
"math-server": {
3+
"coded-math-mcp": {
44
"transport": "stdio",
55
"command": "python",
66
"args": ["server.py"]

samples/mcp-math-server/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ version = "0.0.1"
44
description = "Advanced Math Operations MCP Server"
55
authors = [{ name = "John Doe" }]
66
dependencies = [
7-
"uipath-mcp>=0.0.101",
7+
"uipath-mcp>=0.1.4",
88
]
99
requires-python = ">=3.11"

samples/mcp-math-server/uv.lock

Lines changed: 1387 additions & 1311 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/uipath_mcp/_cli/_runtime/_runtime.py

Lines changed: 109 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@
3737
from .._utils._config import McpServer
3838
from ._context import UiPathServerType
3939
from ._exception import McpErrorCode, UiPathMcpRuntimeError
40-
from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer
40+
from ._session import (
41+
BaseSessionServer,
42+
SessionHealthInfo,
43+
StdioSessionServer,
44+
StreamableHttpSessionServer,
45+
)
4146
from ._token_refresh import TokenRefresher
47+
from ._watchdog import SessionWatchdog
4248

4349
logger = logging.getLogger(__name__)
4450
tracer = trace.get_tracer(__name__)
@@ -86,6 +92,7 @@ def __init__(
8692
self._http_stderr_drain_task: asyncio.Task[None] | None = None
8793
self._http_server_stderr_lines: list[str] = []
8894
self._uipath = UiPath()
95+
self._watchdog: SessionWatchdog | None = None
8996
self._token_refresher: TokenRefresher | None = None
9097
self._cleanup_done = False
9198

@@ -118,6 +125,38 @@ def _validate_auth(self) -> None:
118125
UiPathErrorCategory.SYSTEM,
119126
)
120127

128+
def get_sessions(self) -> dict[str, SessionHealthInfo]:
129+
"""Return health info for all active sessions (SessionProvider protocol)."""
130+
return {
131+
sid: session.get_health_info()
132+
for sid, session in self._session_servers.items()
133+
}
134+
135+
async def remove_session(self, session_id: str, reason: str) -> None:
136+
"""Pop, stop, and clean up a single session (SessionProvider protocol)."""
137+
session_server = self._session_servers.pop(session_id, None)
138+
if session_server is None:
139+
return
140+
141+
logger.warning(f"Removing session {session_id}: {reason}")
142+
143+
try:
144+
await session_server.stop()
145+
except Exception:
146+
logger.error(
147+
f"Error stopping session {session_id}",
148+
exc_info=True,
149+
)
150+
151+
if session_server.output:
152+
if self.sandboxed:
153+
self._session_output = session_server.output
154+
else:
155+
logger.info(f"Session {session_id} output: {session_server.output}")
156+
157+
if self.sandboxed:
158+
self._cancel_event.set()
159+
121160
async def get_schema(self) -> UiPathRuntimeSchema:
122161
"""Get schema for this MCP runtime.
123162
@@ -240,6 +279,9 @@ async def _run_server(self) -> UiPathRuntimeResult:
240279
run_task = asyncio.create_task(self._signalr_client.run())
241280
cancel_task = asyncio.create_task(self._cancel_event.wait())
242281
self._keep_alive_task = asyncio.create_task(self._keep_alive())
282+
283+
self._watchdog = SessionWatchdog(self)
284+
self._watchdog.start()
243285
self._token_refresher.start()
244286

245287
try:
@@ -253,8 +295,8 @@ async def _run_server(self) -> UiPathRuntimeResult:
253295
)
254296
self._cancel_event.set()
255297
finally:
256-
# Cancel any pending tasks gracefully
257-
for task in [run_task, cancel_task, self._keep_alive_task]:
298+
# Cancel pending tasks
299+
for task in [run_task, cancel_task]:
258300
if task and not task.done():
259301
task.cancel()
260302
try:
@@ -280,7 +322,7 @@ async def _run_server(self) -> UiPathRuntimeResult:
280322
except Exception as e:
281323
if isinstance(e, UiPathMcpRuntimeError):
282324
raise
283-
detail = f"Error: {str(e)}"
325+
detail = f"Error: {e}"
284326
raise UiPathMcpRuntimeError(
285327
UiPathErrorCode.EXECUTION_ERROR,
286328
"MCP Runtime execution failed",
@@ -312,11 +354,12 @@ async def _cleanup(self) -> None:
312354
except asyncio.CancelledError:
313355
pass
314356

315-
for session_id, session_server in list(self._session_servers.items()):
316-
try:
317-
await session_server.stop()
318-
except Exception as e:
319-
logger.error(f"Error cleaning up session server {session_id}: {str(e)}")
357+
if self._watchdog:
358+
await self._watchdog.stop()
359+
self._watchdog = None
360+
361+
for session_id in list(self._session_servers.keys()):
362+
await self.remove_session(session_id, reason="runtime shutdown")
320363

321364
# Stop the shared HTTP server process (streamable-http only)
322365
await self._stop_http_server_process()
@@ -327,46 +370,30 @@ async def _cleanup(self) -> None:
327370
try:
328371
await transport._ws.close()
329372
except Exception as e:
330-
logger.error(f"Error closing SignalR WebSocket: {str(e)}")
373+
logger.error(f"Error closing SignalR WebSocket: {e}")
331374

332375
# Add a small delay to allow the server to shut down gracefully
333376
if sys.platform == "win32":
334377
await asyncio.sleep(0.5)
335378

336379
async def _handle_signalr_session_closed(self, args: list[str]) -> None:
337-
"""
338-
Handle session closed by server.
339-
"""
380+
"""Handle session closed by server."""
381+
if self._cleanup_done:
382+
return
383+
340384
if len(args) < 1:
341385
logger.error(f"Received invalid websocket message arguments: {args}")
342386
return
343387

344388
session_id = args[0]
345-
346389
logger.info(f"Received closed signal for session {session_id}")
347-
348-
try:
349-
session_server = self._session_servers.pop(session_id, None)
350-
if session_server:
351-
await session_server.stop()
352-
if session_server.output:
353-
if self.sandboxed:
354-
self._session_output = session_server.output
355-
else:
356-
logger.info(
357-
f"Session {session_id} output: {session_server.output}"
358-
)
359-
# If this is a sandboxed runtime for a specific session, cancel the execution
360-
if self.sandboxed:
361-
self._cancel_event.set()
362-
363-
except Exception as e:
364-
logger.error(f"Error terminating session {session_id}: {str(e)}")
390+
await self.remove_session(session_id, reason="server closed")
365391

366392
async def _handle_signalr_message(self, args: list[str]) -> None:
367-
"""
368-
Handle incoming SignalR messages.
369-
"""
393+
"""Handle incoming SignalR messages."""
394+
if self._cleanup_done:
395+
return
396+
370397
if len(args) < 2:
371398
logger.error(f"Received invalid websocket message arguments: {args}")
372399
return
@@ -392,7 +419,7 @@ async def _handle_signalr_message(self, args: list[str]) -> None:
392419
await session_server.start()
393420
except Exception as e:
394421
logger.error(
395-
f"Error starting session server for session {session_id}: {str(e)}"
422+
f"Error starting session server for session {session_id}: {e}"
396423
)
397424
await self._on_session_start_error(session_id)
398425
raise
@@ -406,7 +433,7 @@ async def _handle_signalr_message(self, args: list[str]) -> None:
406433

407434
except Exception as e:
408435
logger.error(
409-
f"Error handling websocket notification for session {session_id}: {str(e)}"
436+
f"Error handling websocket notification for session {session_id}: {e}"
410437
)
411438

412439
async def _handle_signalr_error(self, error: Any) -> None:
@@ -421,17 +448,21 @@ async def _handle_signalr_close(self) -> None:
421448
"""Handle SignalR connection close event."""
422449
logger.info("Websocket connection closed.")
423450

424-
async def _start_http_server_process(self) -> None:
425-
"""Spawn the streamable-http server process.
426-
427-
The process is started once and shared across all sessions.
428-
"""
451+
def _get_server_env(self) -> dict[str, str]:
452+
"""Return server env vars, with os.environ merged in for Coded servers."""
429453
env_vars = self._server.env.copy()
430454
if self.server_type is UiPathServerType.Coded:
431455
for name, value in os.environ.items():
432456
if name not in env_vars:
433457
env_vars[name] = value
458+
return env_vars
459+
460+
async def _start_http_server_process(self) -> None:
461+
"""Spawn the streamable-http server process.
434462
463+
The process is started once and shared across all sessions.
464+
"""
465+
env_vars = self._get_server_env()
435466
merged_env = {**os.environ, **env_vars} if env_vars else None
436467
self._http_server_stderr_lines = []
437468
self._http_server_process = await asyncio.create_subprocess_exec(
@@ -472,7 +503,12 @@ async def _wait_for_http_server_ready(
472503

473504
url = self._server.url
474505
if not url:
475-
raise ValueError("streamable-http transport requires url in config")
506+
raise UiPathMcpRuntimeError(
507+
McpErrorCode.CONFIGURATION_ERROR,
508+
"Missing URL for streamable-http server",
509+
"Please specify a 'url' in the server configuration for streamable-http transport.",
510+
UiPathErrorCategory.SYSTEM,
511+
)
476512

477513
for attempt in range(max_retries):
478514
# Check if process has crashed
@@ -561,13 +597,9 @@ async def _monitor_http_server_process(self) -> None:
561597
# Stop all HTTP sessions, they will fail on next request anyway
562598
for session_id, session_server in list(self._session_servers.items()):
563599
if isinstance(session_server, StreamableHttpSessionServer):
564-
try:
565-
await session_server.stop()
566-
except Exception as e:
567-
logger.error(
568-
f"Error stopping session {session_id} after process crash: {e}"
569-
)
570-
self._session_servers.pop(session_id, None)
600+
await self.remove_session(
601+
session_id, reason="http process crash"
602+
)
571603
except asyncio.CancelledError:
572604
pass
573605

@@ -577,14 +609,6 @@ async def _register(self) -> None:
577609
initialization_successful = False
578610
tools_result: ListToolsResult | None = None
579611
server_stderr_output = ""
580-
env_vars = self._server.env
581-
582-
# if server is Coded, include environment variables
583-
if self.server_type is UiPathServerType.Coded:
584-
for name, value in os.environ.items():
585-
# config env variables should have precedence over system ones
586-
if name not in env_vars:
587-
env_vars[name] = value
588612

589613
try:
590614
if self._server.is_streamable_http:
@@ -624,7 +648,7 @@ async def _register(self) -> None:
624648
server_params = StdioServerParameters(
625649
command=self._server.command,
626650
args=self._server.args,
627-
env=env_vars,
651+
env=self._get_server_env(),
628652
)
629653

630654
with tempfile.TemporaryFile(mode="w+b") as stderr_temp_binary:
@@ -754,41 +778,39 @@ async def _on_session_start_error(self, session_id: str) -> None:
754778
f"Error sending session dispose signal to UiPath MCP Server: {e}"
755779
)
756780

781+
async def _on_keep_alive_response(self, response: CompletionMessage) -> None:
782+
"""Handle keep-alive response: log session state, detect orphaned sandboxed runtimes."""
783+
if response.error:
784+
logger.error(f"Error during keep-alive: {response.error}")
785+
return
786+
session_ids = response.result
787+
logger.info(f"Server active sessions: {session_ids}")
788+
runtime_sessions = {}
789+
for sid, s in self._session_servers.items():
790+
health = s.get_health_info()
791+
runtime_sessions[sid] = {
792+
"task_done": health.task_done,
793+
"active_requests": len(s._active_requests),
794+
}
795+
logger.info(f"Runtime active sessions: {runtime_sessions}")
796+
# If there are no active sessions and this is a sandbox environment
797+
# We need to cancel the runtime
798+
# eg: when user kills the agent that triggered the runtime, before we subscribe to events
799+
if not session_ids and self.sandboxed and not self._cancel_event.is_set():
800+
logger.warning("No active sessions, cancelling sandboxed runtime...")
801+
self._cancel_event.set()
802+
757803
async def _keep_alive(self) -> None:
758-
"""
759-
Heartbeat to keep the runtime available.
760-
"""
804+
"""Heartbeat to keep the runtime available."""
761805
try:
762806
while not self._cancel_event.is_set():
763807
try:
764-
765-
async def on_keep_alive_response(
766-
response: CompletionMessage,
767-
) -> None:
768-
if response.error:
769-
logger.error(f"Error during keep-alive: {response.error}")
770-
return
771-
session_ids = response.result
772-
logger.info(f"Active sessions: {session_ids}")
773-
# If there are no active sessions and this is a sandbox environment
774-
# We need to cancel the runtime
775-
# eg: when user kills the agent that triggered the runtime, before we subscribe to events
776-
if (
777-
not session_ids
778-
and self.sandboxed
779-
and not self._cancel_event.is_set()
780-
):
781-
logger.error(
782-
"No active sessions, cancelling sandboxed runtime..."
783-
)
784-
self._cancel_event.set()
785-
786808
if self._signalr_client:
787809
logger.info("Sending keep-alive ping...")
788810
await self._signalr_client.send(
789811
method="OnKeepAlive",
790812
arguments=[],
791-
on_invocation=on_keep_alive_response, # type: ignore
813+
on_invocation=self._on_keep_alive_response, # type: ignore
792814
)
793815
else:
794816
logger.error("SignalR client not initialized during keep-alive")
@@ -806,9 +828,7 @@ async def on_keep_alive_response(
806828
raise
807829

808830
async def _on_runtime_abort(self) -> None:
809-
"""
810-
Sends a runtime abort signalr to terminate all connected sessions.
811-
"""
831+
"""Send a runtime abort request to terminate all connected sessions."""
812832
try:
813833
response = await self._uipath.api_client.request_async(
814834
"POST",
@@ -821,7 +841,7 @@ async def _on_runtime_abort(self) -> None:
821841
)
822842
else:
823843
logger.error(
824-
f"Error sending runtime abort signalr to UiPath MCP Server: {response.status_code} - {response.text}"
844+
f"Error sending runtime abort to UiPath MCP Server: {response.status_code} - {response.text}"
825845
)
826846
except Exception as e:
827847
logger.error(

0 commit comments

Comments
 (0)