|
| 1 | +"""Voice tool-call session — persistent socket.io connection to CAS.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import asyncio |
| 6 | +import logging |
| 7 | +import os |
| 8 | +from collections.abc import Awaitable, Callable |
| 9 | +from enum import Enum |
| 10 | +from typing import Any |
| 11 | +from urllib.parse import urlparse |
| 12 | + |
| 13 | +from pydantic import ValidationError |
| 14 | + |
| 15 | +from uipath.core.chat import ( |
| 16 | + UiPathVoiceToolCallMessage, |
| 17 | + UiPathVoiceToolCallRequest, |
| 18 | + UiPathVoiceToolCallResult, |
| 19 | +) |
| 20 | +from uipath.runtime.context import UiPathRuntimeContext |
| 21 | + |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Upper bound (seconds) on the initial socket.io connect attempt to CAS.
_ATTEMPT_CAS_SOCKET_CONNECTION_TIMEOUT_SECONDS = 15.0
# Upper bound (seconds) on waiting for in-flight tool tasks once the session has ended.
_INFLIGHT_TOOL_DRAIN_AFTER_AGENT_END_TIMEOUT_SECONDS = 30.0
| 27 | + |
| 28 | + |
class VoiceToolCallSessionError(RuntimeError):
    """Raised when the voice tool-call session cannot connect to CAS."""
| 31 | + |
| 32 | + |
class VoiceSessionEndReason(str, Enum):
    """Why a voice tool-call session stopped."""

    # CAS announced the end of the session via `voice_session_ended`.
    COMPLETED = "completed"

    # The socket.io connection dropped; also the fallback when no reason was recorded.
    DISCONNECTED = "disconnected"

    # Emitting `voice_tools_ready` right after connecting failed.
    READY_EMIT_FAILED = "ready_emit_failed"
| 37 | + |
| 38 | + |
class VoiceEvent(str, Enum):
    """CAS voice-session protocol events (excludes socket.io lifecycle)."""

    # Events received from CAS.
    TOOL_CALL = "voice_tool_call"
    SESSION_ENDED = "voice_session_ended"

    # Events sent to CAS.
    TOOLS_READY = "voice_tools_ready"
    TOOL_RESULT = "voice_tool_result"
| 46 | + |
| 47 | + |
# Async callback that executes a single tool call: it receives the request
# and resolves to the result that is sent back as `voice_tool_result`.
ToolHandler = Callable[
    [UiPathVoiceToolCallRequest], Awaitable[UiPathVoiceToolCallResult]
]
| 51 | + |
| 52 | + |
class VoiceToolCallSession:
    """Socket.io session with CAS for tool-call traffic.

    Receives `voice_tool_call` batches, emits one `voice_tool_result` per
    `callId`, exits on `voice_session_ended` or disconnect. CAS pulls
    agent config from Orchestrator directly; this session carries only
    tool calls.
    """

    def __init__(
        self,
        url: str,
        socketio_path: str,
        headers: dict[str, str],
        tool_handler: ToolHandler,
    ) -> None:
        """Store connection settings; nothing is connected until `run()`.

        Args:
            url: Socket.io endpoint URL passed to the client's `connect()`.
            socketio_path: Socket.io path passed to the client's `connect()`.
            headers: HTTP headers sent with the connection request.
            tool_handler: Async callable invoked once per received tool call.
        """
        self._url = url
        self._socketio_path = socketio_path
        self._headers = headers
        self._tool_handler = tool_handler
        # Set in run(), where the socketio client is created.
        self._client: Any = None
        # Set exactly when the session should stop (see _end_session).
        self._done = asyncio.Event()
        # Tool tasks that have been spawned and have not finished yet.
        self._in_flight: set[asyncio.Task[None]] = set()
        self._end_reason: VoiceSessionEndReason | None = None

    async def run(self) -> VoiceSessionEndReason:
        """Connect, dispatch tool calls until session ends, then disconnect.

        Returns:
            The first recorded end reason, or DISCONNECTED if none was recorded.

        Raises:
            VoiceToolCallSessionError: If connecting to CAS fails.
        """
        from socketio import AsyncClient  # type: ignore[import-untyped]

        self._client = AsyncClient(logger=False, engineio_logger=False)
        self._client.on("connect", self._handle_connect)
        self._client.on("disconnect", self._handle_disconnect)
        # Register by plain string value so socket.io never sees the enum object.
        self._client.on(VoiceEvent.TOOL_CALL.value, self._handle_tool_call)
        self._client.on(VoiceEvent.SESSION_ENDED.value, self._handle_session_ended)

        try:
            await asyncio.wait_for(
                self._client.connect(
                    url=self._url,
                    socketio_path=self._socketio_path,
                    headers=self._headers,
                    transports=["websocket"],
                ),
                timeout=_ATTEMPT_CAS_SOCKET_CONNECTION_TIMEOUT_SECONDS,
            )
        except Exception as exc:
            await self._safe_disconnect("after connect-failure")
            # A timeout stringifies to "", which would leave the message
            # without any detail; fall back to the exception type name.
            detail = str(exc) or type(exc).__name__
            raise VoiceToolCallSessionError(
                f"Failed to connect to CAS voice endpoint: {detail}"
            ) from exc

        try:
            await self._done.wait()
            await self._drain_in_flight()
        finally:
            await self._safe_disconnect("on shutdown")

        return self._end_reason or VoiceSessionEndReason.DISCONNECTED

    async def _safe_disconnect(self, when: str) -> None:
        """Disconnect the client, logging (not raising) any failure.

        Args:
            when: Short label included in the debug log on failure.
        """
        try:
            await self._client.disconnect()
        except Exception as exc:
            logger.debug("[Voice] disconnect %s raised: %s", when, exc)

    def _end_session(self, reason: VoiceSessionEndReason) -> None:
        """Record why the session ended and wake up `run()`."""
        # First writer wins: a late disconnect must not overwrite COMPLETED.
        if self._end_reason is None:
            self._end_reason = reason
        self._done.set()

    async def _drain_in_flight(self) -> None:
        """Wait for in-flight tool tasks to finish, capped by the drain timeout.

        Tasks still running when the timeout expires are cancelled and
        awaited so none outlive the session.
        """
        if not self._in_flight:
            return

        # Snapshot: done-callbacks remove finished tasks from the live set.
        tasks = set(self._in_flight)
        logger.info(
            "[Voice] Session ended with %d in-flight tool task(s); draining (max %.0fs)",
            len(tasks),
            _INFLIGHT_TOOL_DRAIN_AFTER_AGENT_END_TIMEOUT_SECONDS,
        )

        try:
            # asyncio.wait (unlike wait_for + gather) does not cancel on
            # timeout, so the pending set is an accurate count of stragglers.
            _, unfinished = await asyncio.wait(
                tasks,
                timeout=_INFLIGHT_TOOL_DRAIN_AFTER_AGENT_END_TIMEOUT_SECONDS,
            )
        except asyncio.CancelledError:
            # Cancelling the drain itself must still cancel the tool tasks.
            for task in tasks:
                task.cancel()
            raise

        if not unfinished:
            return

        logger.warning(
            "[Voice] %d tool task(s) did not complete within %.0fs of session end",
            len(unfinished),
            _INFLIGHT_TOOL_DRAIN_AFTER_AGENT_END_TIMEOUT_SECONDS,
        )
        for task in unfinished:
            task.cancel()
        # Let the cancellations settle before the socket is torn down.
        await asyncio.gather(*unfinished, return_exceptions=True)

    async def _handle_connect(self) -> None:
        """Announce readiness to CAS once the socket is connected."""
        logger.info("[Voice] Socket.io connected to CAS")
        try:
            await self._client.emit(VoiceEvent.TOOLS_READY.value, {})
        except Exception as exc:
            # CAS gates tool dispatch on this event; without it the session is dead.
            logger.warning("[Voice] emit voice_tools_ready failed: %s", exc)
            self._end_session(VoiceSessionEndReason.READY_EMIT_FAILED)

    async def _handle_disconnect(self, *_: Any) -> None:
        """End the session when the socket drops.

        Extra positional arguments are ignored.
        NOTE(review): newer python-socketio releases appear to pass a
        disconnect reason to this handler - confirm against the pinned version.
        """
        logger.info("[Voice] Socket.io disconnected from CAS")
        self._end_session(VoiceSessionEndReason.DISCONNECTED)

    async def _handle_tool_call(self, data: dict[str, Any], *_: Any) -> None:
        """Spawn a task per call and return — the reader must stay free for `voice_session_ended`."""
        if self._done.is_set():
            return

        try:
            message = UiPathVoiceToolCallMessage.model_validate(data)
        except ValidationError as exc:
            logger.warning("[Voice] invalid voice_tool_call payload: %s", exc)
            return

        for call in message.calls:
            task = asyncio.create_task(self._execute_tool_call(call))
            # Hold a reference until completion; the callback drops it afterwards.
            self._in_flight.add(task)
            task.add_done_callback(self._in_flight.discard)

    async def _execute_tool_call(self, call: UiPathVoiceToolCallRequest) -> None:
        """Run one tool call and emit its `voice_tool_result`.

        A handler exception is converted into an error result rather than
        propagated; an emit failure is logged and the result is dropped.
        """
        logger.info(
            "[Voice] voice_tool_call dispatched: %s (%s) args=%s",
            call.tool_name,
            call.call_id,
            call.args,
        )
        try:
            tool_result = await self._tool_handler(call)
        except Exception as exc:
            logger.exception("[Voice] Tool call execution failed: %s", call.tool_name)
            tool_result = UiPathVoiceToolCallResult(result=str(exc), is_error=True)

        try:
            await self._client.emit(
                VoiceEvent.TOOL_RESULT.value,
                {"callId": call.call_id, **tool_result.model_dump(by_alias=True)},
            )
        except Exception as exc:
            logger.debug(
                "[Voice] emit voice_tool_result failed for %s: %s", call.call_id, exc
            )
            return
        logger.info(
            "[Voice] voice_tool_result sent: %s (isError=%s)",
            call.call_id,
            tool_result.is_error,
        )

    async def _handle_session_ended(self, _data: Any = None, *_: Any) -> None:
        """End the session as COMPLETED; the event payload (if any) is ignored."""
        logger.info("[Voice] voice_session_ended received")
        self._end_session(VoiceSessionEndReason.COMPLETED)
| 212 | + |
| 213 | + |
def get_voice_bridge(
    context: UiPathRuntimeContext,
    tool_handler: ToolHandler,
) -> VoiceToolCallSession:
    """Factory for a CAS voice tool-call session.

    The endpoint is taken from `CAS_WEBSOCKET_HOST` when set (plain `ws://`
    with the `/socket.io` path); otherwise it is derived from `UIPATH_URL`
    (`wss://` on the same host).

    Args:
        context: Runtime context supplying `conversation_id` and, when
            present, `tenant_id` / `org_id` for the request headers.
        tool_handler: Async callable the session invokes for each tool call.

    Returns:
        A configured, not-yet-connected `VoiceToolCallSession`.

    Raises:
        AssertionError: If `context.conversation_id` is None.
        RuntimeError: If UIPATH_URL is not set or invalid.
    """
    from urllib.parse import urlencode

    conversation_id = context.conversation_id
    # Raised explicitly (not via `assert`) so the guard survives `python -O`;
    # the exception type is kept for backward compatibility.
    if conversation_id is None:
        raise AssertionError("conversation_id must be set in context")

    # Encode the id so characters such as `&` or `#` cannot corrupt the query.
    query = urlencode({"conversationId": conversation_id})

    if cas_host := os.environ.get("CAS_WEBSOCKET_HOST"):
        url = f"ws://{cas_host}?{query}"
        socketio_path = "/socket.io"
        logger.warning(
            "CAS_WEBSOCKET_HOST is set. Using websocket_url '%s%s'.",
            url,
            socketio_path,
        )
    else:
        base_url = os.environ.get("UIPATH_URL")
        if not base_url:
            raise RuntimeError(
                "UIPATH_URL environment variable required for conversational mode"
            )
        parsed = urlparse(base_url)
        if not parsed.netloc:
            raise RuntimeError(f"Invalid UIPATH_URL format: {base_url}")
        url = f"wss://{parsed.netloc}?{query}"
        socketio_path = "autopilotforeveryone_/websocket_/socket.io"

    # Context values win; the environment is only a fallback for tenant/account.
    headers = {
        "Authorization": f"Bearer {os.environ.get('UIPATH_ACCESS_TOKEN', '')}",
        "X-UiPath-Internal-TenantId": context.tenant_id
        or os.environ.get("UIPATH_TENANT_ID", ""),
        "X-UiPath-Internal-AccountId": context.org_id
        or os.environ.get("UIPATH_ORGANIZATION_ID", ""),
        "X-UiPath-ConversationId": conversation_id,
    }

    return VoiceToolCallSession(
        url=url,
        socketio_path=socketio_path,
        headers=headers,
        tool_handler=tool_handler,
    )
0 commit comments