Skip to content

Commit 511d05e

Browse files
committed
fix: use run_coroutine_threadsafe in send_message_sync to prevent cross-loop crash
WebSocketManager.send_message_sync is called from background worker threads (via asyncio.get_event_loop().run_in_executor) during workflow execution — by WebSocketLogger, ArtifactDispatcher, and WebPromptChannel. Previous implementation: try: loop = asyncio.get_running_loop() if loop.is_running(): asyncio.create_task(...) # path only reachable from main thread else: asyncio.run(...) # creates a NEW event loop except RuntimeError: asyncio.run(...) # also creates a new event loop The problem: WebSocket objects are bound to the *main* uvicorn event loop. asyncio.run() spins up a separate event loop and calls websocket.send_text() there, which in Python 3.12 raises: RuntimeError: Task got Future attached to a different loop ...causing all log/artifact/prompt messages emitted from workflow threads to be silently dropped or to crash the worker thread. Fix: - Store the event loop that created the first WebSocket connection as self._owner_loop (captured in connect(), which always runs on the main loop). - send_message_sync schedules the coroutine on that loop via asyncio.run_coroutine_threadsafe(), then waits with a 10 s timeout. - Calling from the main thread still works (run_coroutine_threadsafe is safe when called from any thread, including the loop thread itself). Added 7 tests covering: - send from main thread - send from worker thread (verifies send_text runs on the owner loop thread) - 8 concurrent workers with no lost messages - send after disconnect does not crash - send before connect (no owner loop) does not crash - owner loop captured on first connect - owner loop stable across multiple connects
1 parent cb75e06 commit 511d05e

File tree

2 files changed

+267
-7
lines changed

2 files changed

+267
-7
lines changed

server/services/websocket_manager.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
):
5050
self.active_connections: Dict[str, WebSocket] = {}
5151
self.connection_timestamps: Dict[str, float] = {}
52+
self._owner_loop: Optional[asyncio.AbstractEventLoop] = None
5253
self.session_store = session_store or WorkflowSessionStore()
5354
self.session_controller = session_controller or SessionExecutionController(self.session_store)
5455
self.attachment_service = attachment_service or AttachmentService()
@@ -65,6 +66,10 @@ def __init__(
6566

6667
async def connect(self, websocket: WebSocket, session_id: Optional[str] = None) -> str:
6768
await websocket.accept()
69+
# Capture the event loop that owns the WebSocket connections so that
70+
# worker threads can safely schedule sends via run_coroutine_threadsafe.
71+
if self._owner_loop is None:
72+
self._owner_loop = asyncio.get_running_loop()
6873
if not session_id:
6974
session_id = str(uuid.uuid4())
7075
self.active_connections[session_id] = websocket
@@ -108,14 +113,38 @@ async def send_message(self, session_id: str, message: Dict[str, Any]) -> None:
108113
# self.disconnect(session_id)
109114

110115
def send_message_sync(self, session_id: str, message: Dict[str, Any]) -> None:
116+
"""Send a WebSocket message from any thread (including worker threads).
117+
118+
WebSocket objects are bound to the event loop that created them (the main
119+
uvicorn loop). Previous code called ``asyncio.run()`` from worker threads
120+
which spins up a *new* event loop, causing ``RuntimeError: … attached to a
121+
different loop`` or silent delivery failures.
122+
123+
The fix: always schedule the coroutine on the loop that owns the sockets
124+
via ``asyncio.run_coroutine_threadsafe`` and wait for the result with a
125+
short timeout so the caller knows if delivery failed.
126+
"""
127+
loop = self._owner_loop
128+
if loop is None or loop.is_closed():
129+
logging.warning(
130+
"Cannot send sync message to %s: owner event loop unavailable",
131+
session_id,
132+
)
133+
return
134+
135+
future = asyncio.run_coroutine_threadsafe(
136+
self.send_message(session_id, message), loop
137+
)
111138
try:
112-
loop = asyncio.get_running_loop()
113-
if loop.is_running():
114-
asyncio.create_task(self.send_message(session_id, message))
115-
else:
116-
asyncio.run(self.send_message(session_id, message))
117-
except RuntimeError:
118-
asyncio.run(self.send_message(session_id, message))
139+
future.result(timeout=10)
140+
except TimeoutError:
141+
logging.warning(
142+
"Timed out sending sync WS message to %s", session_id
143+
)
144+
except Exception as exc:
145+
logging.error(
146+
"Error sending sync WS message to %s: %s", session_id, exc
147+
)
119148

120149
async def broadcast(self, message: Dict[str, Any]) -> None:
121150
for session_id in list(self.active_connections.keys()):
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
"""Tests for WebSocketManager.send_message_sync cross-thread safety.
2+
3+
Verifies that send_message_sync correctly delivers messages when called
4+
from worker threads (the common case during workflow execution).
5+
6+
The test avoids importing the full server stack (which has circular import
7+
issues) by patching only the WebSocketManager class directly.
8+
"""
9+
10+
import asyncio
11+
import concurrent.futures
12+
import json
13+
import sys
14+
import threading
15+
import time
16+
from typing import List
17+
from unittest.mock import MagicMock
18+
19+
import pytest
20+
21+
22+
# ---------------------------------------------------------------------------
23+
# Isolate WebSocketManager from the circular-import chain
24+
# ---------------------------------------------------------------------------
25+
26+
# Stub out heavy modules so we can import websocket_manager in isolation
27+
_stubs = {}
28+
for mod_name in (
29+
"check", "check.check",
30+
"runtime", "runtime.sdk", "runtime.bootstrap", "runtime.bootstrap.schema",
31+
"server.services.workflow_run_service",
32+
"server.services.message_handler",
33+
"server.services.attachment_service",
34+
"server.services.session_execution",
35+
"server.services.session_store",
36+
"server.services.artifact_events",
37+
):
38+
if mod_name not in sys.modules:
39+
_stubs[mod_name] = MagicMock()
40+
sys.modules[mod_name] = _stubs[mod_name]
41+
42+
from server.services.websocket_manager import WebSocketManager # noqa: E402
43+
44+
45+
# ---------------------------------------------------------------------------
46+
# Helpers
47+
# ---------------------------------------------------------------------------
48+
49+
def _make_manager() -> WebSocketManager:
50+
"""Create a WebSocketManager with minimal mocks."""
51+
return WebSocketManager(
52+
session_store=MagicMock(),
53+
session_controller=MagicMock(),
54+
attachment_service=MagicMock(),
55+
workflow_run_service=MagicMock(),
56+
)
57+
58+
59+
class FakeWebSocket:
60+
"""Lightweight fake that records sent messages and the thread they arrived on."""
61+
62+
def __init__(self) -> None:
63+
self.sent: List[str] = []
64+
self.send_threads: List[int] = []
65+
66+
async def accept(self) -> None:
67+
pass
68+
69+
async def send_text(self, data: str) -> None:
70+
self.sent.append(data)
71+
self.send_threads.append(threading.get_ident())
72+
73+
74+
# ---------------------------------------------------------------------------
75+
# Tests
76+
# ---------------------------------------------------------------------------
77+
78+
class TestSendMessageSync:
79+
"""send_message_sync must deliver messages regardless of calling thread."""
80+
81+
def test_send_from_main_thread(self):
82+
"""Message sent from the main (event-loop) thread is delivered."""
83+
manager = _make_manager()
84+
ws = FakeWebSocket()
85+
delivered = []
86+
87+
async def run():
88+
sid = await manager.connect(ws, session_id="s1")
89+
# Drain the initial "connection" message
90+
ws.sent.clear()
91+
92+
manager.send_message_sync(sid, {"type": "test", "data": "hello"})
93+
# Give the scheduled coroutine a moment to execute
94+
await asyncio.sleep(0.05)
95+
delivered.extend(ws.sent)
96+
97+
asyncio.run(run())
98+
assert len(delivered) == 1
99+
assert '"test"' in delivered[0]
100+
101+
def test_send_from_worker_thread(self):
102+
"""Message sent from a background (worker) thread is delivered on the owner loop."""
103+
manager = _make_manager()
104+
ws = FakeWebSocket()
105+
worker_errors: List[Exception] = []
106+
107+
async def run():
108+
sid = await manager.connect(ws, session_id="s2")
109+
ws.sent.clear()
110+
main_thread = threading.get_ident()
111+
112+
def worker():
113+
try:
114+
manager.send_message_sync(sid, {"type": "from_worker"})
115+
except Exception as exc:
116+
worker_errors.append(exc)
117+
118+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
119+
future = pool.submit(worker)
120+
# Let the worker thread finish and the scheduled coro run
121+
while not future.done():
122+
await asyncio.sleep(0.01)
123+
future.result() # re-raise if worker threw
124+
await asyncio.sleep(0.1)
125+
126+
# Verify delivery
127+
assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}"
128+
assert '"from_worker"' in ws.sent[0]
129+
130+
# Verify send_text ran on the main loop thread, not the worker
131+
assert ws.send_threads[0] == main_thread
132+
133+
asyncio.run(run())
134+
assert not worker_errors, f"Worker thread raised: {worker_errors}"
135+
136+
def test_concurrent_workers_no_lost_messages(self):
137+
"""Multiple concurrent workers should each have their message delivered.
138+
139+
In production, the event loop is free while workers run (the main coroutine
140+
awaits ``run_in_executor``). We replicate that by polling workers via
141+
``asyncio.sleep`` so the loop can process the scheduled sends.
142+
"""
143+
manager = _make_manager()
144+
ws = FakeWebSocket()
145+
num_workers = 8
146+
147+
async def run():
148+
sid = await manager.connect(ws, session_id="s3")
149+
ws.sent.clear()
150+
151+
barrier = threading.Barrier(num_workers)
152+
done_count = threading.atomic(0) if hasattr(threading, "atomic") else None
153+
done_flags = [False] * num_workers
154+
155+
def worker(idx: int):
156+
barrier.wait(timeout=5)
157+
manager.send_message_sync(sid, {"type": "msg", "idx": idx})
158+
done_flags[idx] = True
159+
160+
pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
161+
futures = [pool.submit(worker, i) for i in range(num_workers)]
162+
163+
# Yield control so the loop can process sends while workers run
164+
deadline = time.time() + 15
165+
while not all(done_flags) and time.time() < deadline:
166+
await asyncio.sleep(0.05)
167+
168+
# Collect any worker exceptions
169+
for f in futures:
170+
f.result(timeout=1)
171+
172+
# Let remaining coros drain
173+
await asyncio.sleep(0.3)
174+
pool.shutdown(wait=False)
175+
176+
assert len(ws.sent) == num_workers, (
177+
f"Expected {num_workers} messages, got {len(ws.sent)}"
178+
)
179+
180+
asyncio.run(run())
181+
182+
def test_send_after_disconnect_does_not_crash(self):
183+
"""Sending after disconnection should not raise."""
184+
manager = _make_manager()
185+
ws = FakeWebSocket()
186+
187+
async def run():
188+
sid = await manager.connect(ws, session_id="s4")
189+
manager.disconnect(sid)
190+
191+
# Should silently skip, not crash
192+
manager.send_message_sync(sid, {"type": "late"})
193+
await asyncio.sleep(0.05)
194+
195+
asyncio.run(run()) # no exception == pass
196+
197+
def test_send_before_any_connection_no_crash(self):
198+
"""Calling send_message_sync before any connect() should not crash."""
199+
manager = _make_manager()
200+
# _owner_loop is None
201+
manager.send_message_sync("nonexistent", {"type": "orphan"})
202+
# Should log a warning, not crash
203+
204+
205+
class TestOwnerLoopCapture:
206+
"""The manager must capture the event loop on first connect."""
207+
208+
def test_owner_loop_captured_on_connect(self):
209+
manager = _make_manager()
210+
ws = FakeWebSocket()
211+
212+
async def run():
213+
assert manager._owner_loop is None
214+
await manager.connect(ws, session_id="cap1")
215+
assert manager._owner_loop is asyncio.get_running_loop()
216+
217+
asyncio.run(run())
218+
219+
def test_owner_loop_stable_across_connections(self):
220+
"""Subsequent connects should not reset the owner loop."""
221+
manager = _make_manager()
222+
ws1 = FakeWebSocket()
223+
ws2 = FakeWebSocket()
224+
225+
async def run():
226+
await manager.connect(ws1, session_id="cap2")
227+
loop1 = manager._owner_loop
228+
await manager.connect(ws2, session_id="cap3")
229+
assert manager._owner_loop is loop1
230+
231+
asyncio.run(run())

0 commit comments

Comments
 (0)