"""WebSocket router for per-session streaming agent chat.

Endpoint:
    WS /ws/sessions/{session_id}/chat?token=<JWT>

Client → Server message types:
    {"type": "message", "content": "..."}
    {"type": "interrupt"}
    {"type": "ping"}

Server → Client message types:
    {"type": "text_delta", "content": "..."}
    {"type": "tool_use_start", "tool_name": "...", "tool_input": {...}}
    {"type": "tool_result", "tool_name": "...", "content": "..."}
    {"type": "thinking", "content": "..."}
    {"type": "cost_update", "cost_usd": 0.003, "input_tokens": 100, "output_tokens": 50}
    {"type": "done"}
    {"type": "error", "message": "..."}
    {"type": "pong"}

Architecture note: This router intentionally handles session orchestration
(activation, interrupt coordination, cost aggregation) rather than delegating
to a core service. A future refactoring (TODO: extract to
core.session_chat_service) would make this a thin transport adapter. See
GitHub issue #502 for context.

Workspace scoping: This endpoint is intentionally exempt from workspace_path
query param validation. Auth is already scoped to a user via JWT and the
session_id identifies the resource — workspace scoping on a per-session WS
would be redundant. Revisit if multi-tenant workspace isolation is required.
"""
| 32 | + |
| 33 | +import asyncio |
| 34 | +import json |
| 35 | +import logging |
| 36 | +from typing import Optional |
| 37 | + |
| 38 | +import jwt as pyjwt |
| 39 | +from fastapi import APIRouter, WebSocket, WebSocketDisconnect |
| 40 | +from sqlalchemy import select |
| 41 | + |
| 42 | +from codeframe.auth.manager import SECRET, JWT_ALGORITHM, JWT_AUDIENCE, get_async_session_maker |
| 43 | +from codeframe.auth.models import User |
| 44 | +from codeframe.ui.shared import session_chat_manager |
| 45 | + |
| 46 | +logger = logging.getLogger(__name__) |
| 47 | + |
| 48 | +router = APIRouter(tags=["websocket"]) |
| 49 | + |
| 50 | + |
| 51 | +# --------------------------------------------------------------------------- |
| 52 | +# Helpers |
| 53 | +# --------------------------------------------------------------------------- |
| 54 | + |
| 55 | + |
async def _authenticate_websocket(websocket: WebSocket) -> Optional[int]:
    """Authenticate a WebSocket connection from its ``token`` query parameter.

    The token is decoded with ``SECRET``, ``JWT_ALGORITHM`` and
    ``JWT_AUDIENCE``. Its ``sub`` claim is converted to an integer user id,
    and that id must match a ``User`` row whose ``is_active`` is truthy.

    Args:
        websocket: The incoming connection. On every failure path it is
            closed here with code 1008 and a short reason string.

    Returns:
        The authenticated user id, or ``None`` when the connection was
        rejected (the socket has already been closed in that case).

    NOTE(review): ``close()`` is called before ``accept()`` on this path;
    confirm that the ASGI server in use actually delivers the close code and
    reason to the client for an un-accepted socket.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=1008, reason="Authentication required: missing token")
        return None

    # Keep the try body to the decode call only, so a failure while closing
    # the socket is never mistaken for a token problem.
    try:
        payload = pyjwt.decode(token, SECRET, algorithms=[JWT_ALGORITHM], audience=JWT_AUDIENCE)
    except pyjwt.ExpiredSignatureError:
        await websocket.close(code=1008, reason="Token expired")
        return None
    except pyjwt.InvalidTokenError as exc:
        logger.debug("WebSocket JWT decode error: %s", exc)
        await websocket.close(code=1008, reason="Invalid authentication token")
        return None

    subject = payload.get("sub")
    if not subject:
        await websocket.close(code=1008, reason="Invalid token: missing subject")
        return None

    try:
        user_id = int(subject)
    except (TypeError, ValueError) as exc:
        # TypeError covers a ``sub`` claim that is neither a string nor a
        # number (e.g. a list or object); ValueError covers non-numeric text.
        logger.debug("WebSocket JWT decode error: %s", exc)
        await websocket.close(code=1008, reason="Invalid authentication token")
        return None

    # Decide on a rejection reason while the DB session is open, then close
    # the socket exactly once afterwards.
    rejection: Optional[str] = None
    try:
        async_session_maker = get_async_session_maker()
        async with async_session_maker() as session:
            result = await session.execute(select(User).where(User.id == user_id))
            user = result.scalar_one_or_none()
            if user is None:
                rejection = "User not found"
            elif not user.is_active:
                rejection = "User is inactive"
    except Exception as exc:
        # Boundary handler: any lookup failure is reported to the client as a
        # generic authentication failure; the traceback goes to the log.
        logger.error("WebSocket user lookup error: %s", exc, exc_info=True)
        rejection = "Authentication failed"

    if rejection is not None:
        await websocket.close(code=1008, reason=rejection)
        return None

    return user_id
| 95 | + |
| 96 | + |
async def _run_streaming_adapter(
    session_id: str,
    user_message: str,
    token_queue: asyncio.Queue,
    interrupt_event: asyncio.Event,
) -> None:
    """Placeholder for the real Anthropic streaming adapter (tracked in #503).

    The signature is the integration contract: a real implementation must push
    dicts shaped like the server→client protocol into ``token_queue`` and poll
    ``interrupt_event.is_set()`` while it works.

    This stub echoes ``user_message`` back one whitespace-separated word at a
    time as ``text_delta`` events (a single ``"..."`` when there are no
    words). An uninterrupted run ends with a zeroed ``cost_update`` followed
    by ``done``; an interrupted run ends with ``done`` only.
    """
    pieces = user_message.split()
    if not pieces:
        pieces = ["..."]

    for piece in pieces:
        if interrupt_event.is_set():
            # Interrupted: skip the cost report and fall through to "done".
            break
        await token_queue.put({"type": "text_delta", "content": f"{piece} "})
        await asyncio.sleep(0.01)
    else:
        # Only reached when the loop was not broken out of.
        zero_cost = {
            "type": "cost_update",
            "cost_usd": 0.0,
            "input_tokens": 0,
            "output_tokens": 0,
        }
        await token_queue.put(zero_cost)

    await token_queue.put({"type": "done"})
| 121 | + |
| 122 | + |
| 123 | +# --------------------------------------------------------------------------- |
| 124 | +# Endpoint |
| 125 | +# --------------------------------------------------------------------------- |
| 126 | + |
| 127 | + |
@router.websocket("/ws/sessions/{session_id}/chat")
async def session_chat_ws(session_id: str, websocket: WebSocket) -> None:
    """Bidirectional WebSocket for streaming agent chat on an interactive session.

    Flow:
      1. Authenticate via ``_authenticate_websocket``; return if it rejected.
      2. Load the session from ``websocket.app.state.db``. Close with 1011
         when no ``db`` is attached, or 4008 when the session is missing or
         its ``state`` is ``"ended"``.
      3. Accept the socket, mark the session ``"active"`` and register the
         socket with ``session_chat_manager``.
      4. Run a relay task (token queue → client) alongside the receive loop
         (client → ping / interrupt / message dispatch).
      5. Always cancel both tasks, unregister and close the socket on exit.

    Args:
        session_id: Path parameter identifying the interactive session.
        websocket: The connection being served.
    """
    # --- Auth ---
    user_id = await _authenticate_websocket(websocket)
    if user_id is None:
        return
    # NOTE(review): user_id is not compared against the session record loaded
    # below, so this handler does not itself verify that the session belongs
    # to the authenticated user. Confirm that ownership is enforced elsewhere.

    # --- Session validation ---
    db = getattr(websocket.app.state, "db", None)
    if db is None:
        await websocket.close(code=1011, reason="Database unavailable")
        return

    session = await asyncio.to_thread(db.interactive_sessions.get, session_id)
    if session is None or session.get("state") == "ended":
        await websocket.close(code=4008, reason="Session not found or ended")
        return

    # --- Accept connection; everything that can fail after this point runs
    #     inside the try/finally so unregister() and close() always execute
    #     even if update_state, register, or get_token_queue raises. ---
    await websocket.accept()

    relay: Optional[asyncio.Task] = None
    adapter: Optional[asyncio.Task] = None

    async def _send_error(message: str) -> None:
        """Best-effort delivery of an ``error`` frame; never raises."""
        try:
            await websocket.send_json({"type": "error", "message": message})
        except Exception as exc:
            logger.debug(
                "session_id=%s could not deliver error frame: %s", session_id, exc
            )

    async def _persist_cost(turn_cost: dict) -> None:
        """Write accumulated usage via ``update_cost``; no-op when all zero.

        Exceptions from the database call propagate so each call site can log
        with its own context.
        """
        if not any(turn_cost.values()):
            return
        await asyncio.to_thread(
            db.interactive_sessions.update_cost,
            session_id,
            turn_cost["cost_usd"],
            turn_cost["input_tokens"],
            turn_cost["output_tokens"],
        )

    try:
        await asyncio.to_thread(db.interactive_sessions.update_state, session_id, "active")
        await session_chat_manager.register(session_id, websocket)

        token_queue = await session_chat_manager.get_token_queue(session_id)

        # ---- Relay task: token_queue → WebSocket ----------------------------
        async def _relay() -> None:
            """Forward adapter events to the client; persist cost on 'done'."""
            turn_cost = {"cost_usd": 0.0, "input_tokens": 0, "output_tokens": 0}
            try:
                while True:
                    event = await token_queue.get()
                    event_type = event.get("type")

                    if event_type == "cost_update":
                        turn_cost["cost_usd"] += event.get("cost_usd", 0.0)
                        turn_cost["input_tokens"] += event.get("input_tokens", 0)
                        turn_cost["output_tokens"] += event.get("output_tokens", 0)

                    elif event_type == "done":
                        # Persist cost BEFORE sending "done" so clients that
                        # immediately fetch session stats observe accurate totals.
                        try:
                            await _persist_cost(turn_cost)
                        except Exception as exc:
                            logger.error(
                                "session_id=%s update_cost failed: %s turn_cost=%s",
                                session_id,
                                exc,
                                turn_cost,
                            )
                        # Start the next turn from zero whether or not the
                        # write succeeded.
                        turn_cost = {"cost_usd": 0.0, "input_tokens": 0, "output_tokens": 0}

                    # Every event, whatever its type, is forwarded to the
                    # client; a failed send ends the relay.
                    try:
                        await websocket.send_json(event)
                    except Exception as exc:
                        logger.warning(
                            "session_id=%s send_json(%s) failed: %s",
                            session_id,
                            event_type,
                            exc,
                        )
                        return
            finally:
                # Flush any cost accumulated during a cancelled/aborted turn
                try:
                    await _persist_cost(turn_cost)
                except Exception as exc:
                    logger.error(
                        "session_id=%s relay finally update_cost failed: %s", session_id, exc
                    )

        # ---- Receive task: WebSocket → dispatch -----------------------------
        async def _receive() -> None:
            """Read client messages and dispatch to ping/interrupt/message handlers."""
            nonlocal adapter
            while True:
                # WebSocketDisconnect propagates to the outer handler.
                raw = await websocket.receive_text()

                try:
                    msg = json.loads(raw)
                except json.JSONDecodeError:
                    await _send_error("Invalid JSON")
                    continue

                # Valid JSON that is not an object (e.g. a list or a bare
                # string) has no "type" to dispatch on; reject it without
                # tearing down the connection.
                if not isinstance(msg, dict):
                    await _send_error("Invalid message: expected a JSON object")
                    continue

                msg_type = msg.get("type")

                if msg_type == "ping":
                    await websocket.send_json({"type": "pong"})

                elif msg_type == "interrupt":
                    await session_chat_manager.signal_interrupt(session_id)

                elif msg_type == "message":
                    content = msg.get("content", "")
                    if not isinstance(content, str):
                        # The adapter takes user_message as a str; reject
                        # anything else here rather than let the background
                        # task fail without the client ever seeing "done".
                        await _send_error("Invalid message: content must be a string")
                        continue

                    # Cancel any in-flight adapter
                    if adapter is not None and not adapter.done():
                        adapter.cancel()
                        try:
                            await adapter
                        except (asyncio.CancelledError, Exception) as exc:
                            logger.debug(
                                "session_id=%s adapter cancelled: %s", session_id, exc
                            )

                    # Reset interrupt and drain stale queue items
                    await session_chat_manager.reset_interrupt(session_id)
                    while True:
                        try:
                            token_queue.get_nowait()
                        except asyncio.QueueEmpty:
                            break

                    interrupt_event = await session_chat_manager.get_interrupt_event(session_id)
                    adapter = asyncio.create_task(
                        _run_streaming_adapter(session_id, content, token_queue, interrupt_event)
                    )

                else:
                    # Unknown types are ignored, as before; leave a trace.
                    logger.debug(
                        "session_id=%s ignoring message type %r", session_id, msg_type
                    )

        relay = asyncio.create_task(_relay())
        await _receive()

    except WebSocketDisconnect:
        logger.debug("Session chat WebSocket disconnected: session_id=%s", session_id)
    except Exception as exc:
        logger.error("Session chat WebSocket error: %s", exc, exc_info=True)
        # Do not echo str(exc) to the client: it can expose internal details.
        # The full error and traceback are in the server log above.
        await _send_error("Internal server error")
    finally:
        if relay is not None:
            relay.cancel()
            try:
                await relay
            except (asyncio.CancelledError, Exception) as exc:
                logger.debug("session_id=%s relay cleanup: %s", session_id, exc)
        if adapter is not None and not adapter.done():
            adapter.cancel()
            try:
                await adapter
            except (asyncio.CancelledError, Exception) as exc:
                logger.debug("session_id=%s adapter cleanup: %s", session_id, exc)
        try:
            await session_chat_manager.unregister(session_id, websocket)
        finally:
            # Close even if unregister() raised; the socket may already be
            # closed (e.g. after a client disconnect), so this is best-effort.
            try:
                await websocket.close()
            except Exception as exc:
                logger.debug("session_id=%s close after cleanup: %s", session_id, exc)
0 commit comments