Skip to content

Commit 8eefe6f

Browse files
authored
feat(api): WebSocket endpoint for streaming agent chat (#502)
Adds /ws/sessions/{session_id}/chat WebSocket endpoint for bidirectional streaming agent chat. - JWT auth (close 1008 on failure), session validation (close 4008 for ended/unknown) - Concurrent receive + relay tasks for non-blocking interrupt handling - Cost/token DB update persisted before sending "done" (accurate client reads) - Cleanup guard wraps all post-accept logic in try/finally - SessionChatManager tracks connections, interrupt events, and token queues per session; unregister() checks websocket identity to prevent late disconnects from tearing newer state - Stub _run_streaming_adapter() ready for real Anthropic integration (#503) - 12 integration tests covering auth, session validation, protocol, interrupt, cost, cleanup Closes #502
1 parent ed5d899 commit 8eefe6f

4 files changed

Lines changed: 668 additions & 0 deletions

File tree

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
"""WebSocket router for per-session streaming agent chat.
2+
3+
Endpoint:
4+
WS /ws/sessions/{session_id}/chat?token=<JWT>
5+
6+
Client → Server message types:
7+
{"type": "message", "content": "..."}
8+
{"type": "interrupt"}
9+
{"type": "ping"}
10+
11+
Server → Client message types:
12+
{"type": "text_delta", "content": "..."}
13+
{"type": "tool_use_start", "tool_name": "...", "tool_input": {...}}
14+
{"type": "tool_result", "tool_name": "...", "content": "..."}
15+
{"type": "thinking", "content": "..."}
16+
{"type": "cost_update", "cost_usd": 0.003, "input_tokens": 100, "output_tokens": 50}
17+
{"type": "done"}
18+
{"type": "error", "message": "..."}
19+
{"type": "pong"}
20+
21+
Architecture note: This router intentionally handles session orchestration
22+
(activation, interrupt coordination, cost aggregation) rather than delegating
23+
to a core service. A future refactoring (TODO: extract to
24+
core.session_chat_service) would make this a thin transport adapter. See
25+
GitHub issue #502 for context.
26+
27+
Workspace scoping: This endpoint is intentionally exempt from workspace_path
28+
query param validation. Auth is already scoped to a user via JWT and the
29+
session_id identifies the resource — workspace scoping on a per-session WS
30+
would be redundant. Revisit if multi-tenant workspace isolation is required.
31+
"""
32+
33+
import asyncio
34+
import json
35+
import logging
36+
from typing import Optional
37+
38+
import jwt as pyjwt
39+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
40+
from sqlalchemy import select
41+
42+
from codeframe.auth.manager import SECRET, JWT_ALGORITHM, JWT_AUDIENCE, get_async_session_maker
43+
from codeframe.auth.models import User
44+
from codeframe.ui.shared import session_chat_manager
45+
46+
logger = logging.getLogger(__name__)
47+
48+
router = APIRouter(tags=["websocket"])
49+
50+
51+
# ---------------------------------------------------------------------------
52+
# Helpers
53+
# ---------------------------------------------------------------------------
54+
55+
56+
async def _authenticate_websocket(websocket: WebSocket) -> Optional[int]:
57+
"""Validate JWT from query param. Returns user_id or closes with 1008."""
58+
token = websocket.query_params.get("token")
59+
if not token:
60+
await websocket.close(code=1008, reason="Authentication required: missing token")
61+
return None
62+
63+
try:
64+
payload = pyjwt.decode(token, SECRET, algorithms=[JWT_ALGORITHM], audience=JWT_AUDIENCE)
65+
user_id_str = payload.get("sub")
66+
if not user_id_str:
67+
await websocket.close(code=1008, reason="Invalid token: missing subject")
68+
return None
69+
user_id = int(user_id_str)
70+
except pyjwt.ExpiredSignatureError:
71+
await websocket.close(code=1008, reason="Token expired")
72+
return None
73+
except (pyjwt.InvalidTokenError, ValueError) as exc:
74+
logger.debug("WebSocket JWT decode error: %s", exc)
75+
await websocket.close(code=1008, reason="Invalid authentication token")
76+
return None
77+
78+
try:
79+
async_session_maker = get_async_session_maker()
80+
async with async_session_maker() as session:
81+
result = await session.execute(select(User).where(User.id == user_id))
82+
user = result.scalar_one_or_none()
83+
if user is None:
84+
await websocket.close(code=1008, reason="User not found")
85+
return None
86+
if not user.is_active:
87+
await websocket.close(code=1008, reason="User is inactive")
88+
return None
89+
except Exception as exc:
90+
logger.error("WebSocket user lookup error: %s", exc)
91+
await websocket.close(code=1008, reason="Authentication failed")
92+
return None
93+
94+
return user_id
95+
96+
97+
async def _run_streaming_adapter(
98+
session_id: str,
99+
user_message: str,
100+
token_queue: asyncio.Queue,
101+
interrupt_event: asyncio.Event,
102+
) -> None:
103+
"""Stub for the real Anthropic streaming adapter (tracked in #503).
104+
105+
Signature is the integration point — replace the body when the real adapter
106+
is ready. Must push dicts matching the server→client protocol into
107+
token_queue, and check interrupt_event.is_set() periodically.
108+
"""
109+
words = user_message.split() or ["..."]
110+
for word in words:
111+
if interrupt_event.is_set():
112+
await token_queue.put({"type": "done"})
113+
return
114+
await token_queue.put({"type": "text_delta", "content": word + " "})
115+
await asyncio.sleep(0.01)
116+
117+
await token_queue.put(
118+
{"type": "cost_update", "cost_usd": 0.0, "input_tokens": 0, "output_tokens": 0}
119+
)
120+
await token_queue.put({"type": "done"})
121+
122+
123+
# ---------------------------------------------------------------------------
124+
# Endpoint
125+
# ---------------------------------------------------------------------------
126+
127+
128+
@router.websocket("/ws/sessions/{session_id}/chat")
129+
async def session_chat_ws(session_id: str, websocket: WebSocket) -> None:
130+
"""Bidirectional WebSocket for streaming agent chat on an interactive session."""
131+
# --- Auth ---
132+
user_id = await _authenticate_websocket(websocket)
133+
if user_id is None:
134+
return
135+
136+
# --- Session validation ---
137+
db = getattr(websocket.app.state, "db", None)
138+
if db is None:
139+
await websocket.close(code=1011, reason="Database unavailable")
140+
return
141+
142+
session = await asyncio.to_thread(db.interactive_sessions.get, session_id)
143+
if session is None or session.get("state") == "ended":
144+
await websocket.close(code=4008, reason="Session not found or ended")
145+
return
146+
147+
# --- Accept connection; everything after this point must run inside the
148+
# try/finally so unregister() and close() always execute even if
149+
# update_state, register, or get_token_queue raises. ---
150+
await websocket.accept()
151+
152+
relay: Optional[asyncio.Task] = None
153+
adapter_task: list[Optional[asyncio.Task]] = [None]
154+
155+
try:
156+
await asyncio.to_thread(db.interactive_sessions.update_state, session_id, "active")
157+
await session_chat_manager.register(session_id, websocket)
158+
159+
token_queue = await session_chat_manager.get_token_queue(session_id)
160+
161+
# ---- Relay task: token_queue → WebSocket ----------------------------
162+
async def _relay() -> None:
163+
"""Forward adapter events to the client; persist cost on 'done'."""
164+
turn_cost = {"cost_usd": 0.0, "input_tokens": 0, "output_tokens": 0}
165+
try:
166+
while True:
167+
event = await token_queue.get()
168+
event_type = event.get("type")
169+
170+
if event_type == "cost_update":
171+
turn_cost["cost_usd"] += event.get("cost_usd", 0.0)
172+
turn_cost["input_tokens"] += event.get("input_tokens", 0)
173+
turn_cost["output_tokens"] += event.get("output_tokens", 0)
174+
try:
175+
await websocket.send_json(event)
176+
except Exception as exc:
177+
logger.warning(
178+
"session_id=%s send_json(cost_update) failed: %s", session_id, exc
179+
)
180+
return
181+
182+
elif event_type == "done":
183+
# Persist cost BEFORE sending "done" so clients that
184+
# immediately fetch session stats observe accurate totals.
185+
if (
186+
turn_cost["cost_usd"]
187+
or turn_cost["input_tokens"]
188+
or turn_cost["output_tokens"]
189+
):
190+
try:
191+
await asyncio.to_thread(
192+
db.interactive_sessions.update_cost,
193+
session_id,
194+
turn_cost["cost_usd"],
195+
turn_cost["input_tokens"],
196+
turn_cost["output_tokens"],
197+
)
198+
except Exception as exc:
199+
logger.error(
200+
"session_id=%s update_cost failed: %s turn_cost=%s",
201+
session_id,
202+
exc,
203+
turn_cost,
204+
)
205+
turn_cost = {"cost_usd": 0.0, "input_tokens": 0, "output_tokens": 0}
206+
try:
207+
await websocket.send_json(event)
208+
except Exception as exc:
209+
logger.warning(
210+
"session_id=%s send_json(done) failed: %s", session_id, exc
211+
)
212+
return
213+
214+
else:
215+
try:
216+
await websocket.send_json(event)
217+
except Exception as exc:
218+
logger.warning(
219+
"session_id=%s send_json(%s) failed: %s",
220+
session_id,
221+
event_type,
222+
exc,
223+
)
224+
return
225+
finally:
226+
# Flush any cost accumulated during a cancelled/aborted turn
227+
if (
228+
turn_cost["cost_usd"]
229+
or turn_cost["input_tokens"]
230+
or turn_cost["output_tokens"]
231+
):
232+
try:
233+
await asyncio.to_thread(
234+
db.interactive_sessions.update_cost,
235+
session_id,
236+
turn_cost["cost_usd"],
237+
turn_cost["input_tokens"],
238+
turn_cost["output_tokens"],
239+
)
240+
except Exception as exc:
241+
logger.error(
242+
"session_id=%s relay finally update_cost failed: %s", session_id, exc
243+
)
244+
245+
# ---- Receive task: WebSocket → dispatch -----------------------------
246+
async def _receive() -> None:
247+
"""Read client messages and dispatch to ping/interrupt/message handlers."""
248+
while True:
249+
try:
250+
raw = await websocket.receive_text()
251+
except WebSocketDisconnect:
252+
raise
253+
254+
try:
255+
msg = json.loads(raw)
256+
except json.JSONDecodeError:
257+
try:
258+
await websocket.send_json({"type": "error", "message": "Invalid JSON"})
259+
except Exception:
260+
pass
261+
continue
262+
263+
msg_type = msg.get("type")
264+
265+
if msg_type == "ping":
266+
await websocket.send_json({"type": "pong"})
267+
268+
elif msg_type == "interrupt":
269+
await session_chat_manager.signal_interrupt(session_id)
270+
271+
elif msg_type == "message":
272+
content = msg.get("content", "")
273+
274+
# Cancel any in-flight adapter
275+
if adapter_task[0] and not adapter_task[0].done():
276+
adapter_task[0].cancel()
277+
try:
278+
await adapter_task[0]
279+
except (asyncio.CancelledError, Exception) as exc:
280+
logger.debug(
281+
"session_id=%s adapter cancelled: %s", session_id, exc
282+
)
283+
284+
# Reset interrupt and drain stale queue items
285+
await session_chat_manager.reset_interrupt(session_id)
286+
while not token_queue.empty():
287+
try:
288+
token_queue.get_nowait()
289+
except asyncio.QueueEmpty:
290+
break
291+
292+
interrupt_event = await session_chat_manager.get_interrupt_event(session_id)
293+
adapter_task[0] = asyncio.create_task(
294+
_run_streaming_adapter(session_id, content, token_queue, interrupt_event)
295+
)
296+
297+
relay = asyncio.create_task(_relay())
298+
await _receive()
299+
300+
except WebSocketDisconnect:
301+
logger.debug("Session chat WebSocket disconnected: session_id=%s", session_id)
302+
except Exception as exc:
303+
logger.error("Session chat WebSocket error: %s", exc, exc_info=True)
304+
try:
305+
await websocket.send_json({"type": "error", "message": str(exc)})
306+
except Exception:
307+
pass
308+
finally:
309+
if relay is not None:
310+
relay.cancel()
311+
try:
312+
await relay
313+
except (asyncio.CancelledError, Exception):
314+
pass
315+
if adapter_task[0] and not adapter_task[0].done():
316+
adapter_task[0].cancel()
317+
try:
318+
await adapter_task[0]
319+
except (asyncio.CancelledError, Exception) as exc:
320+
logger.debug("session_id=%s adapter cleanup: %s", session_id, exc)
321+
await session_chat_manager.unregister(session_id, websocket)
322+
try:
323+
await websocket.close()
324+
except Exception:
325+
pass

codeframe/ui/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
proof_v2,
3535
review_v2,
3636
schedule_v2,
37+
session_chat_ws,
3738
streaming_v2,
3839
tasks_v2,
3940
templates_v2,
@@ -487,6 +488,7 @@ async def test_broadcast(message: dict, project_id: int = None):
487488
app.include_router(gates_v2.router) # /api/v2/gates
488489
app.include_router(git_v2.router) # /api/v2/git
489490
app.include_router(interactive_sessions_v2.router) # /api/v2/sessions
491+
app.include_router(session_chat_ws.router) # /ws/sessions/{id}/chat
490492
app.include_router(pr_v2.router) # /api/v2/pr
491493
app.include_router(prd_v2.router) # /api/v2/prd
492494
app.include_router(proof_v2.router) # /api/v2/proof

0 commit comments

Comments
 (0)