|
| 1 | +"""Tests for WebSocketManager.send_message_sync cross-thread safety. |
| 2 | +
|
| 3 | +Verifies that send_message_sync correctly delivers messages when called |
| 4 | +from worker threads (the common case during workflow execution). |
| 5 | +
|
| 6 | +The test avoids importing the full server stack (which has circular import |
| 7 | +issues) by patching only the WebSocketManager class directly. |
| 8 | +""" |
| 9 | + |
| 10 | +import asyncio |
| 11 | +import concurrent.futures |
| 12 | +import json |
| 13 | +import sys |
| 14 | +import threading |
| 15 | +import time |
| 16 | +from typing import List |
| 17 | +from unittest.mock import MagicMock |
| 18 | + |
| 19 | +import pytest |
| 20 | + |
| 21 | + |
| 22 | +# --------------------------------------------------------------------------- |
| 23 | +# Isolate WebSocketManager from the circular-import chain |
| 24 | +# --------------------------------------------------------------------------- |
| 25 | + |
| 26 | +# Stub out heavy modules so we can import websocket_manager in isolation |
| 27 | +_stubs = {} |
| 28 | +for mod_name in ( |
| 29 | + "check", "check.check", |
| 30 | + "runtime", "runtime.sdk", "runtime.bootstrap", "runtime.bootstrap.schema", |
| 31 | + "server.services.workflow_run_service", |
| 32 | + "server.services.message_handler", |
| 33 | + "server.services.attachment_service", |
| 34 | + "server.services.session_execution", |
| 35 | + "server.services.session_store", |
| 36 | + "server.services.artifact_events", |
| 37 | +): |
| 38 | + if mod_name not in sys.modules: |
| 39 | + _stubs[mod_name] = MagicMock() |
| 40 | + sys.modules[mod_name] = _stubs[mod_name] |
| 41 | + |
| 42 | +from server.services.websocket_manager import WebSocketManager # noqa: E402 |
| 43 | + |
| 44 | + |
| 45 | +# --------------------------------------------------------------------------- |
| 46 | +# Helpers |
| 47 | +# --------------------------------------------------------------------------- |
| 48 | + |
| 49 | +def _make_manager() -> WebSocketManager: |
| 50 | + """Create a WebSocketManager with minimal mocks.""" |
| 51 | + return WebSocketManager( |
| 52 | + session_store=MagicMock(), |
| 53 | + session_controller=MagicMock(), |
| 54 | + attachment_service=MagicMock(), |
| 55 | + workflow_run_service=MagicMock(), |
| 56 | + ) |
| 57 | + |
| 58 | + |
| 59 | +class FakeWebSocket: |
| 60 | + """Lightweight fake that records sent messages and the thread they arrived on.""" |
| 61 | + |
| 62 | + def __init__(self) -> None: |
| 63 | + self.sent: List[str] = [] |
| 64 | + self.send_threads: List[int] = [] |
| 65 | + |
| 66 | + async def accept(self) -> None: |
| 67 | + pass |
| 68 | + |
| 69 | + async def send_text(self, data: str) -> None: |
| 70 | + self.sent.append(data) |
| 71 | + self.send_threads.append(threading.get_ident()) |
| 72 | + |
| 73 | + |
| 74 | +# --------------------------------------------------------------------------- |
| 75 | +# Tests |
| 76 | +# --------------------------------------------------------------------------- |
| 77 | + |
| 78 | +class TestSendMessageSync: |
| 79 | + """send_message_sync must deliver messages regardless of calling thread.""" |
| 80 | + |
| 81 | + def test_send_from_main_thread(self): |
| 82 | + """Message sent from the main (event-loop) thread is delivered.""" |
| 83 | + manager = _make_manager() |
| 84 | + ws = FakeWebSocket() |
| 85 | + delivered = [] |
| 86 | + |
| 87 | + async def run(): |
| 88 | + sid = await manager.connect(ws, session_id="s1") |
| 89 | + # Drain the initial "connection" message |
| 90 | + ws.sent.clear() |
| 91 | + |
| 92 | + manager.send_message_sync(sid, {"type": "test", "data": "hello"}) |
| 93 | + # Give the scheduled coroutine a moment to execute |
| 94 | + await asyncio.sleep(0.05) |
| 95 | + delivered.extend(ws.sent) |
| 96 | + |
| 97 | + asyncio.run(run()) |
| 98 | + assert len(delivered) == 1 |
| 99 | + assert '"test"' in delivered[0] |
| 100 | + |
| 101 | + def test_send_from_worker_thread(self): |
| 102 | + """Message sent from a background (worker) thread is delivered on the owner loop.""" |
| 103 | + manager = _make_manager() |
| 104 | + ws = FakeWebSocket() |
| 105 | + worker_errors: List[Exception] = [] |
| 106 | + |
| 107 | + async def run(): |
| 108 | + sid = await manager.connect(ws, session_id="s2") |
| 109 | + ws.sent.clear() |
| 110 | + main_thread = threading.get_ident() |
| 111 | + |
| 112 | + def worker(): |
| 113 | + try: |
| 114 | + manager.send_message_sync(sid, {"type": "from_worker"}) |
| 115 | + except Exception as exc: |
| 116 | + worker_errors.append(exc) |
| 117 | + |
| 118 | + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: |
| 119 | + future = pool.submit(worker) |
| 120 | + # Let the worker thread finish and the scheduled coro run |
| 121 | + while not future.done(): |
| 122 | + await asyncio.sleep(0.01) |
| 123 | + future.result() # re-raise if worker threw |
| 124 | + await asyncio.sleep(0.1) |
| 125 | + |
| 126 | + # Verify delivery |
| 127 | + assert len(ws.sent) == 1, f"Expected 1 message, got {len(ws.sent)}" |
| 128 | + assert '"from_worker"' in ws.sent[0] |
| 129 | + |
| 130 | + # Verify send_text ran on the main loop thread, not the worker |
| 131 | + assert ws.send_threads[0] == main_thread |
| 132 | + |
| 133 | + asyncio.run(run()) |
| 134 | + assert not worker_errors, f"Worker thread raised: {worker_errors}" |
| 135 | + |
| 136 | + def test_concurrent_workers_no_lost_messages(self): |
| 137 | + """Multiple concurrent workers should each have their message delivered. |
| 138 | +
|
| 139 | + In production, the event loop is free while workers run (the main coroutine |
| 140 | + awaits ``run_in_executor``). We replicate that by polling workers via |
| 141 | + ``asyncio.sleep`` so the loop can process the scheduled sends. |
| 142 | + """ |
| 143 | + manager = _make_manager() |
| 144 | + ws = FakeWebSocket() |
| 145 | + num_workers = 8 |
| 146 | + |
| 147 | + async def run(): |
| 148 | + sid = await manager.connect(ws, session_id="s3") |
| 149 | + ws.sent.clear() |
| 150 | + |
| 151 | + barrier = threading.Barrier(num_workers) |
| 152 | + done_count = threading.atomic(0) if hasattr(threading, "atomic") else None |
| 153 | + done_flags = [False] * num_workers |
| 154 | + |
| 155 | + def worker(idx: int): |
| 156 | + barrier.wait(timeout=5) |
| 157 | + manager.send_message_sync(sid, {"type": "msg", "idx": idx}) |
| 158 | + done_flags[idx] = True |
| 159 | + |
| 160 | + pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) |
| 161 | + futures = [pool.submit(worker, i) for i in range(num_workers)] |
| 162 | + |
| 163 | + # Yield control so the loop can process sends while workers run |
| 164 | + deadline = time.time() + 15 |
| 165 | + while not all(done_flags) and time.time() < deadline: |
| 166 | + await asyncio.sleep(0.05) |
| 167 | + |
| 168 | + # Collect any worker exceptions |
| 169 | + for f in futures: |
| 170 | + f.result(timeout=1) |
| 171 | + |
| 172 | + # Let remaining coros drain |
| 173 | + await asyncio.sleep(0.3) |
| 174 | + pool.shutdown(wait=False) |
| 175 | + |
| 176 | + assert len(ws.sent) == num_workers, ( |
| 177 | + f"Expected {num_workers} messages, got {len(ws.sent)}" |
| 178 | + ) |
| 179 | + |
| 180 | + asyncio.run(run()) |
| 181 | + |
| 182 | + def test_send_after_disconnect_does_not_crash(self): |
| 183 | + """Sending after disconnection should not raise.""" |
| 184 | + manager = _make_manager() |
| 185 | + ws = FakeWebSocket() |
| 186 | + |
| 187 | + async def run(): |
| 188 | + sid = await manager.connect(ws, session_id="s4") |
| 189 | + manager.disconnect(sid) |
| 190 | + |
| 191 | + # Should silently skip, not crash |
| 192 | + manager.send_message_sync(sid, {"type": "late"}) |
| 193 | + await asyncio.sleep(0.05) |
| 194 | + |
| 195 | + asyncio.run(run()) # no exception == pass |
| 196 | + |
| 197 | + def test_send_before_any_connection_no_crash(self): |
| 198 | + """Calling send_message_sync before any connect() should not crash.""" |
| 199 | + manager = _make_manager() |
| 200 | + # _owner_loop is None |
| 201 | + manager.send_message_sync("nonexistent", {"type": "orphan"}) |
| 202 | + # Should log a warning, not crash |
| 203 | + |
| 204 | + |
| 205 | +class TestOwnerLoopCapture: |
| 206 | + """The manager must capture the event loop on first connect.""" |
| 207 | + |
| 208 | + def test_owner_loop_captured_on_connect(self): |
| 209 | + manager = _make_manager() |
| 210 | + ws = FakeWebSocket() |
| 211 | + |
| 212 | + async def run(): |
| 213 | + assert manager._owner_loop is None |
| 214 | + await manager.connect(ws, session_id="cap1") |
| 215 | + assert manager._owner_loop is asyncio.get_running_loop() |
| 216 | + |
| 217 | + asyncio.run(run()) |
| 218 | + |
| 219 | + def test_owner_loop_stable_across_connections(self): |
| 220 | + """Subsequent connects should not reset the owner loop.""" |
| 221 | + manager = _make_manager() |
| 222 | + ws1 = FakeWebSocket() |
| 223 | + ws2 = FakeWebSocket() |
| 224 | + |
| 225 | + async def run(): |
| 226 | + await manager.connect(ws1, session_id="cap2") |
| 227 | + loop1 = manager._owner_loop |
| 228 | + await manager.connect(ws2, session_id="cap3") |
| 229 | + assert manager._owner_loop is loop1 |
| 230 | + |
| 231 | + asyncio.run(run()) |
0 commit comments