Skip to content

Commit 2c07433

Browse files
committed
feat: add watchdog for runtime session handling
1 parent 775726b commit 2c07433

5 files changed

Lines changed: 484 additions & 3 deletions

File tree

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ cython_debug/
165165
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166166
# and can be added to the global gitignore or merged into this file. For a more nuclear
167167
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
168-
#.idea/
168+
.idea/
169+
170+
.vscode/
171+
172+
.claude/
169173

170174
# Ruff stuff:
171175
.ruff_cache/

src/uipath_mcp/_cli/_runtime/_runtime.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,14 @@
3737
from .._utils._config import McpServer
3838
from ._context import UiPathServerType
3939
from ._exception import McpErrorCode, UiPathMcpRuntimeError
40-
from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer
40+
from ._session import (
41+
BaseSessionServer,
42+
SessionHealthInfo,
43+
StdioSessionServer,
44+
StreamableHttpSessionServer,
45+
)
4146
from ._token_refresh import TokenRefresher
47+
from ._watchdog import SessionWatchdog
4248

4349
logger = logging.getLogger(__name__)
4450
tracer = trace.get_tracer(__name__)
@@ -86,6 +92,7 @@ def __init__(
8692
self._http_stderr_drain_task: asyncio.Task[None] | None = None
8793
self._http_server_stderr_lines: list[str] = []
8894
self._uipath = UiPath()
95+
self._watchdog: SessionWatchdog | None = None
8996
self._token_refresher: TokenRefresher | None = None
9097
self._cleanup_done = False
9198

@@ -118,6 +125,46 @@ def _validate_auth(self) -> None:
118125
UiPathErrorCategory.SYSTEM,
119126
)
120127

128+
def get_sessions(self) -> dict[str, SessionHealthInfo]:
    """Collect a health snapshot for every session currently registered.

    Implements the SessionProvider protocol.

    Returns:
        A new dict mapping each session id in ``self._session_servers`` to
        the value returned by that session's ``get_health_info()``.
    """
    health_by_id: dict[str, SessionHealthInfo] = {}
    for session_id, server in self._session_servers.items():
        health_by_id[session_id] = server.get_health_info()
    return health_by_id
134+
135+
# Removes the entry for `session_id` from self._session_servers (a missing id
# is tolerated: pop() falls back to None). For a session that was present, a
# warning carrying `reason` is logged and its stop() is awaited; any Exception
# raised by stop() is logged with its traceback and not re-raised.
# The UiPath server is then told about the removal via
# _close_session_on_server(session_id).
# NOTE(review): indentation is lost in this diff view, so it cannot be told
# here whether the _close_session_on_server() call sits inside the
# `if session_server is not None` branch or runs for every call - confirm
# against the real file before relying on either behaviour.
async def remove_session(self, session_id: str, reason: str) -> None:
136+
"""Remove and stop a session by ID (SessionProvider protocol)."""
137+
session_server = self._session_servers.pop(session_id, None)
138+
if session_server is not None:
139+
logger.warning(
140+
f"Removing session {session_id}: {reason}"
141+
)
142+
try:
143+
await session_server.stop()
144+
except Exception:
145+
logger.error(
146+
f"Error stopping session {session_id} during watchdog removal",
147+
exc_info=True,
148+
)
149+
await self._close_session_on_server(session_id)
150+
151+
async def _close_session_on_server(self, session_id: str) -> None:
    """Ask the UiPath server to drop a session so it stops sending messages.

    Issues a DELETE for this runtime's folder key and slug, identifying the
    session through the ``mcp-session-id`` header. This is best-effort:
    failures are logged and never raised to the caller.

    Args:
        session_id: Identifier of the session to close on the server.
    """
    try:
        send_request = self._uipath.api_client.request_async
        endpoint = f"agenthub_/mcp/{self._folder_key}/{self.slug}"
        session_header = {"mcp-session-id": session_id}
        await send_request("DELETE", endpoint, headers=session_header)
        logger.info(f"Notified server of session closure: {session_id}")
    except Exception as err:
        # A 404 means the server no longer knows the session, which is the
        # state we wanted anyway; everything else is reported as an error.
        already_gone = (
            isinstance(err, HTTPStatusError) and err.response.status_code == 404
        )
        if already_gone:
            logger.info(f"Session {session_id} already removed server-side")
        else:
            logger.error(f"Error closing session {session_id} on server: {err}")
167+
121168
async def get_schema(self) -> UiPathRuntimeSchema:
122169
"""Get schema for this MCP runtime.
123170
@@ -240,6 +287,9 @@ async def _run_server(self) -> UiPathRuntimeResult:
240287
run_task = asyncio.create_task(self._signalr_client.run())
241288
cancel_task = asyncio.create_task(self._cancel_event.wait())
242289
self._keep_alive_task = asyncio.create_task(self._keep_alive())
290+
291+
self._watchdog = SessionWatchdog(self)
292+
self._watchdog.start()
243293
self._token_refresher.start()
244294

245295
try:
@@ -312,6 +362,10 @@ async def _cleanup(self) -> None:
312362
except asyncio.CancelledError:
313363
pass
314364

365+
if self._watchdog:
366+
await self._watchdog.stop()
367+
self._watchdog = None
368+
315369
for session_id, session_server in list(self._session_servers.items()):
316370
try:
317371
await session_server.stop()
@@ -367,6 +421,10 @@ async def _handle_signalr_message(self, args: list[str]) -> None:
367421
"""
368422
Handle incoming SignalR messages.
369423
"""
424+
425+
if self._cleanup_done:
426+
return
427+
370428
if len(args) < 2:
371429
logger.error(f"Received invalid websocket message arguments: {args}")
372430
return
@@ -769,7 +827,15 @@ async def on_keep_alive_response(
769827
logger.error(f"Error during keep-alive: {response.error}")
770828
return
771829
session_ids = response.result
772-
logger.info(f"Active sessions: {session_ids}")
830+
logger.info(f"Server active sessions: {session_ids}")
831+
runtime_sessions = {}
832+
for sid, s in self._session_servers.items():
833+
health = s.get_health_info()
834+
runtime_sessions[sid] = {
835+
"task_done": health.task_done,
836+
"active_requests": health.active_request_count,
837+
}
838+
logger.info(f"Runtime active sessions: {runtime_sessions}")
773839
# If there are no active sessions and this is a sandbox environment
774840
# We need to cancel the runtime
775841
# eg: when user kills the agent that triggered the runtime, before we subscribe to events
@@ -783,6 +849,7 @@ async def on_keep_alive_response(
783849
)
784850
self._cancel_event.set()
785851

852+
786853
if self._signalr_client:
787854
logger.info("Sending keep-alive ping...")
788855
await self._signalr_client.send(

src/uipath_mcp/_cli/_runtime/_session.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
import io
33
import logging
44
import tempfile
5+
import time
56
from abc import ABC, abstractmethod
7+
from dataclasses import dataclass
68
from typing import Any
79

10+
from anyio import EndOfStream
811
from mcp import StdioServerParameters, stdio_client
912
from mcp.client.streamable_http import streamable_http_client
1013
from mcp.shared.message import SessionMessage
@@ -28,6 +31,19 @@
2831
RETRY_DELAY = 1
2932

3033

34+
@dataclass
class SessionHealthInfo:
    """Point-in-time health snapshot of a single session, used by the watchdog.

    Field order defines the positional constructor signature and is kept
    exactly as declared.
    """

    # Identifier of the session this snapshot describes.
    session_id: str
    # Transport label reported by the session server (e.g. "stdio", "streamable-http").
    transport_type: str
    # True when the session's run task has finished, or when it has no run task.
    task_done: bool
    # Exception retrieved from the finished run task, or None.
    task_exception: BaseException | None
    # time.monotonic() reading recorded at the session's most recent activity.
    last_activity_time: float
    # Size of the session's message queue when the snapshot was taken.
    queue_size: int
    # Number of requests the session was tracking as active.
    active_request_count: int
45+
46+
3147
class BaseSessionServer(ABC):
3248
"""Base class with transport-agnostic message relay logic."""
3349

@@ -48,9 +64,16 @@ def __init__(
4864
self._active_requests: dict[str, str] = {}
4965
self._last_request_id: str | None = None
5066
self._last_message_id: str | None = None
67+
self._last_activity_time: float = time.monotonic()
5168
self._uipath = uipath
5269
self._mcp_tracer = McpTracer(tracer, logger)
5370

71+
@property
@abstractmethod
def transport_type(self) -> str:
    """Identifier of the transport this session uses.

    Concrete session servers supply the value, e.g. 'stdio' or
    'streamable-http'.
    """
76+
5477
@property
5578
@abstractmethod
5679
def output(self) -> str | None:
@@ -79,8 +102,28 @@ async def stop(self) -> None:
79102
self._read_stream = None
80103
self._write_stream = None
81104

105+
def get_health_info(self) -> SessionHealthInfo:
    """Build a health snapshot of this session.

    Returns:
        A ``SessionHealthInfo`` describing the run task's completion state
        and exception (if any), the last activity timestamp, the message
        queue size, and the number of active requests.
    """
    run_task = self._run_task
    # A session without a run task is reported as done.
    finished = run_task.done() if run_task else True

    failure: BaseException | None = None
    if run_task is not None and finished:
        try:
            failure = run_task.exception()
        except (asyncio.CancelledError, asyncio.InvalidStateError):
            # exception() raises rather than returns in these cases;
            # report "no exception" instead of propagating.
            failure = None

    pending_messages = self._message_queue.qsize()
    in_flight = len(self._active_requests)
    return SessionHealthInfo(
        session_id=self._session_id,
        transport_type=self.transport_type,
        task_done=finished,
        task_exception=failure,
        last_activity_time=self._last_activity_time,
        queue_size=pending_messages,
        active_request_count=in_flight,
    )
123+
82124
async def on_message_received(self, request_id: str) -> None:
83125
"""Get new incoming messages from UiPath MCP Server."""
126+
self._last_activity_time = time.monotonic()
84127
for attempt in range(MAX_RETRIES + 1):
85128
try:
86129
await self._get_messages_internal(request_id)
@@ -115,6 +158,7 @@ async def _relay_messages(self) -> None:
115158
break
116159

117160
session_message = await self._read_stream.receive()
161+
self._last_activity_time = time.monotonic()
118162
if isinstance(session_message, Exception):
119163
logger.error(f"Received error: {session_message}")
120164
continue
@@ -137,6 +181,11 @@ async def _relay_messages(self) -> None:
137181
# For non-responses, use the last known request_id
138182
if self._last_request_id is not None:
139183
await self._send_message(message, self._last_request_id)
184+
except EndOfStream:
185+
logger.warning(
186+
f"Read stream closed for session {self._session_id}"
187+
)
188+
break
140189
except Exception as e:
141190
if session_message:
142191
logger.info(session_message)
@@ -292,6 +341,10 @@ class StdioSessionServer(BaseSessionServer):
292341

293342
_server_stderr_output: str | None = None
294343

344+
@property
def transport_type(self) -> str:
    """Transport identifier for stdio-backed sessions."""
    return "stdio"
347+
295348
@property
296349
def output(self) -> str | None:
297350
"""Returns the captured stderr output from the MCP server process."""
@@ -345,6 +398,10 @@ async def _run_server(self, server_params: StdioServerParameters) -> None:
345398
class StreamableHttpSessionServer(BaseSessionServer):
346399
"""Manages an HTTP connection to a shared streamable-http server for a specific session."""
347400

401+
@property
def transport_type(self) -> str:
    """Transport identifier for streamable-http sessions."""
    return "streamable-http"
404+
348405
@property
349406
def output(self) -> str | None:
350407
"""Returns captured output from the server process, if any."""

0 commit comments

Comments
 (0)