Skip to content

Commit b89b0da

Browse files
committed
ws client hardening
1 parent d372fac commit b89b0da

2 files changed

Lines changed: 195 additions & 15 deletions

File tree

src/mistapi/websockets/__ws_client.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
) # tracks whether the WebSocket connection is currently open
7171
self._user_disconnect = threading.Event()
7272
self._finished = threading.Event()
73+
self._finished.set() # not running initially
7374
self._reconnect_attempts = 0
7475
self._last_close_code: int | None = None
7576
self._last_close_msg: str | None = None
@@ -164,7 +165,6 @@ def _handle_open(self, ws: websocket.WebSocketApp) -> None:
164165
ws.send(json.dumps({"subscribe": channel}))
165166
except Exception as exc:
166167
logger.error("Subscription send failed: %s", exc)
167-
self._handle_error(ws, exc)
168168
ws.close()
169169
return
170170
self._reconnect_attempts = 0
@@ -235,12 +235,10 @@ def connect(self, run_in_background: bool = True) -> None:
235235
If False, blocks the calling thread until disconnected.
236236
"""
237237
with self._lock:
238-
if self._connected.is_set() or (
239-
self._thread is not None and self._thread.is_alive()
240-
):
238+
if self._connected.is_set() or not self._finished.is_set():
241239
raise RuntimeError("Already connected; call disconnect() first")
242-
self._user_disconnect.clear()
243240
self._finished.clear()
241+
self._user_disconnect.clear()
244242
self._reconnect_attempts = 0
245243
# Drain stale sentinel from previous connection
246244
while not self._queue.empty():
@@ -309,15 +307,14 @@ def _run_forever_safe(self) -> None:
309307
except Exception:
310308
pass
311309

312-
# Final close: put sentinel, call callback
313-
self._queue.put(None)
310+
finally:
311+
self._queue.put(None) # sentinel — unblocks receive()
312+
self._finished.set() # mark as not running — unblocks connect()
314313
if self._on_close_cb:
315314
try:
316315
self._on_close_cb(self._last_close_code, self._last_close_msg)
317316
except Exception:
318317
logger.exception("on_close callback raised")
319-
finally:
320-
self._finished.set()
321318

322319
def disconnect(self, wait: bool = False, timeout: float | None = None) -> None:
323320
"""Close the WebSocket connection.
@@ -336,7 +333,8 @@ def disconnect(self, wait: bool = False, timeout: float | None = None) -> None:
336333
if ws:
337334
ws.close()
338335
if wait and self._thread is not None:
339-
self._thread.join(timeout=timeout)
336+
if self._thread is not threading.current_thread():
337+
self._thread.join(timeout=timeout)
340338

341339
def receive(self) -> Generator[dict, None, None]:
342340
"""
@@ -346,7 +344,13 @@ def receive(self) -> Generator[dict, None, None]:
346344
the server closes the connection).
347345
348346
Intended for use after connect(run_in_background=True).
347+
Cannot be used when an on_message callback is registered.
349348
"""
349+
if self._on_message_cb is not None:
350+
raise RuntimeError(
351+
"receive() cannot be used when an on_message callback is "
352+
"registered; use one or the other"
353+
)
350354
if self._auto_reconnect:
351355
while (
352356
not self._connected.is_set()
@@ -357,11 +361,15 @@ def receive(self) -> Generator[dict, None, None]:
357361
if not self._connected.is_set():
358362
return
359363
elif not self._connected.wait(timeout=10):
360-
return
364+
if not self._finished.is_set():
365+
return
366+
# Thread already finished — fall through to drain queued messages
361367
while True:
362368
try:
363369
item = self._queue.get(timeout=1)
364370
except queue.Empty:
371+
if self._finished.is_set() and self._queue.empty():
372+
break
365373
if not self._connected.is_set() and self._queue.empty():
366374
if (
367375
self._auto_reconnect

tests/unit/test_websocket_client.py

Lines changed: 176 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,21 +1026,40 @@ def bad_cb(code, msg):
10261026
# Sentinel should still be in the queue
10271027
assert ws_client._queue.get_nowait() is None
10281028

1029+
@patch("mistapi.websockets.__ws_client.websocket.WebSocketApp")
1030+
def test_connect_from_on_close_callback(self, mock_ws_cls, mock_session) -> None:
1031+
"""connect() must work when called from inside the on_close callback."""
1032+
mock_ws_cls.return_value = Mock()
1033+
client = _MistWebsocket(mock_session, channels=["/ch"])
1034+
errors = []
1035+
1036+
def on_close_cb(code, msg):
1037+
try:
1038+
client.connect(run_in_background=True)
1039+
except RuntimeError as exc:
1040+
errors.append(exc)
1041+
finally:
1042+
client.disconnect()
1043+
1044+
client.on_close(on_close_cb)
1045+
mock_ws = Mock()
1046+
client._ws = mock_ws
1047+
client._finished.clear() # simulate connect() was called
1048+
client._run_forever_safe()
1049+
1050+
assert not errors, f"on_close callback raised: {errors}"
1051+
10291052
def test_on_open_send_failure_closes_connection(self, ws_client) -> None:
10301053
"""If ws.send() raises during subscription, connection is closed."""
10311054
mock_ws = Mock()
10321055
mock_ws.send.side_effect = ConnectionError("send failed")
1033-
error_cb = Mock()
1034-
ws_client.on_error(error_cb)
10351056

10361057
ws_client._handle_open(mock_ws)
10371058

10381059
# Connection should NOT be marked as connected
10391060
assert not ws_client._connected.is_set()
10401061
# ws.close() should have been called
10411062
mock_ws.close.assert_called_once()
1042-
# Error callback should have been invoked
1043-
error_cb.assert_called_once()
10441063

10451064

10461065
# ---------------------------------------------------------------------------
@@ -1064,6 +1083,12 @@ def test_no_callback_uses_queue(self, ws_client) -> None:
10641083
assert not ws_client._queue.empty()
10651084
assert ws_client._queue.get_nowait() == {"event": "data"}
10661085

1086+
def test_receive_raises_when_message_callback_registered(self, ws_client) -> None:
1087+
ws_client.on_message(Mock())
1088+
ws_client._connected.set()
1089+
with pytest.raises(RuntimeError, match="on_message callback"):
1090+
list(ws_client.receive())
1091+
10671092

10681093
# ---------------------------------------------------------------------------
10691094
# disconnect(wait=...)
@@ -1084,6 +1109,26 @@ def test_disconnect_wait_blocks_until_thread_finishes(
10841109
assert client._finished.is_set()
10851110
assert not client._thread.is_alive()
10861111

1112+
def test_disconnect_wait_from_callback_does_not_self_join(
1113+
self, mock_session
1114+
) -> None:
1115+
"""disconnect(wait=True) from inside a callback must not raise."""
1116+
client = _MistWebsocket(mock_session, channels=["/ch"])
1117+
error_from_cb = []
1118+
1119+
def on_close_cb(code, msg):
1120+
try:
1121+
client.disconnect(wait=True, timeout=1)
1122+
except Exception as exc:
1123+
error_from_cb.append(exc)
1124+
1125+
client.on_close(on_close_cb)
1126+
mock_ws = Mock()
1127+
client._ws = mock_ws
1128+
client._run_forever_safe()
1129+
1130+
assert not error_from_cb, f"Callback raised: {error_from_cb}"
1131+
10871132

10881133
# ---------------------------------------------------------------------------
10891134
# Cookie edge cases
@@ -1131,3 +1176,130 @@ def test_session_with_url_rejects_http(self, mock_session) -> None:
11311176
def test_session_with_url_accepts_wss(self, mock_session) -> None:
11321177
ws = SessionWithUrl(mock_session, url="wss://api-ws.mist.com/stream")
11331178
assert ws._build_ws_url() == "wss://api-ws.mist.com/stream"
1179+
1180+
1181+
# ---------------------------------------------------------------------------
1182+
# Connect / disconnect / connect cycle
1183+
# ---------------------------------------------------------------------------
1184+
1185+
1186+
class TestConnectDisconnectCycle:
1187+
"""Verify that connect → disconnect → connect works cleanly."""
1188+
1189+
@patch("mistapi.websockets.__ws_client.websocket.WebSocketApp")
1190+
def test_reconnect_after_disconnect(self, mock_ws_cls, mock_session) -> None:
1191+
mock_ws_cls.return_value = Mock()
1192+
client = _MistWebsocket(mock_session, channels=["/ch"])
1193+
1194+
# First cycle
1195+
client.connect(run_in_background=True)
1196+
client.disconnect(wait=True, timeout=5)
1197+
assert client._finished.is_set()
1198+
1199+
# Second cycle — should not raise
1200+
client.connect(run_in_background=True)
1201+
client.disconnect(wait=True, timeout=5)
1202+
assert client._finished.is_set()
1203+
1204+
1205+
# ---------------------------------------------------------------------------
1206+
# receive() exits when thread dies (no sentinel scenario)
1207+
# ---------------------------------------------------------------------------
1208+
1209+
1210+
class TestReceiveFinishedExit:
1211+
"""Verify receive() exits when _finished is set even without a sentinel."""
1212+
1213+
def test_receive_exits_when_finished_set_without_sentinel(
1214+
self, ws_client
1215+
) -> None:
1216+
"""Simulates a BaseException scenario where sentinel is never queued."""
1217+
ws_client._connected.set()
1218+
# Simulate: thread died, _finished set, _connected still set, no sentinel
1219+
ws_client._finished.set()
1220+
# receive() should exit promptly
1221+
results = list(ws_client.receive())
1222+
assert results == []
1223+
1224+
def test_receive_exits_when_finished_set_with_connected_cleared(
1225+
self, ws_client
1226+
) -> None:
1227+
ws_client._connected.set()
1228+
ws_client._connected.clear()
1229+
ws_client._finished.set()
1230+
results = list(ws_client.receive())
1231+
assert results == []
1232+
1233+
def test_receive_drains_queue_when_connection_closed_before_receive(
1234+
self, ws_client
1235+
) -> None:
1236+
"""If connection opens, messages arrive, and connection closes before
1237+
receive() is called, the queued messages should still be yielded."""
1238+
ws_client._finished.clear() # simulate connect() was called
1239+
# Simulate: connection opened, messages arrived, connection closed
1240+
ws_client._queue.put({"event": "a"})
1241+
ws_client._queue.put({"event": "b"})
1242+
ws_client._queue.put(None) # sentinel from _run_forever_safe
1243+
ws_client._finished.set() # thread finished
1244+
# _connected was set then cleared — currently unset
1245+
1246+
results = list(ws_client.receive())
1247+
assert results == [{"event": "a"}, {"event": "b"}]
1248+
1249+
1250+
# ---------------------------------------------------------------------------
1251+
# Blocking connect guard
1252+
# ---------------------------------------------------------------------------
1253+
1254+
1255+
class TestBlockingConnectGuard:
1256+
"""Verify that the guard prevents double-connect in blocking mode."""
1257+
1258+
def test_connect_blocking_sets_finished_cleared(self, mock_session) -> None:
1259+
"""_finished is cleared inside connect(), preventing concurrent connect."""
1260+
client = _MistWebsocket(mock_session, channels=["/ch"])
1261+
assert client._finished.is_set() # starts set = ready
1262+
1263+
mock_ws = Mock()
1264+
with patch(
1265+
"mistapi.websockets.__ws_client.websocket.WebSocketApp",
1266+
return_value=mock_ws,
1267+
):
1268+
client.connect(run_in_background=False)
1269+
1270+
# After _run_forever_safe returns, _finished is set again
1271+
assert client._finished.is_set()
1272+
1273+
def test_double_connect_raises_while_blocking(self, mock_session) -> None:
1274+
"""If a blocking connect is in progress, a concurrent connect raises."""
1275+
client = _MistWebsocket(mock_session, channels=["/ch"])
1276+
barrier = threading.Event()
1277+
1278+
def blocking_run_forever(**kwargs):
1279+
barrier.wait(timeout=5) # block until test releases
1280+
1281+
mock_ws = Mock()
1282+
mock_ws.run_forever.side_effect = blocking_run_forever
1283+
1284+
with patch(
1285+
"mistapi.websockets.__ws_client.websocket.WebSocketApp",
1286+
return_value=mock_ws,
1287+
):
1288+
t = threading.Thread(
1289+
target=client.connect, kwargs={"run_in_background": False}
1290+
)
1291+
t.start()
1292+
1293+
# Wait for _finished to be cleared (connect entered)
1294+
for _ in range(50):
1295+
if not client._finished.is_set():
1296+
break
1297+
threading.Event().wait(timeout=0.05)
1298+
1299+
# Second connect should raise
1300+
with pytest.raises(RuntimeError, match="Already connected"):
1301+
client.connect(run_in_background=True)
1302+
1303+
barrier.set() # release blocking thread
1304+
t.join(timeout=5)
1305+
assert not t.is_alive()

0 commit comments

Comments
 (0)