Skip to content

Commit 0b63172

Browse files
authored
fix: propagate connection errors to subscription queues (#1681)
Closes #1678
1 parent 9ce7e68 commit 0b63172

3 files changed

Lines changed: 309 additions & 10 deletions

File tree

binance/ws/reconnecting_websocket.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
pass
1616

1717
try:
18-
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
18+
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
1919
except ImportError:
2020
from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore
2121

@@ -78,6 +78,10 @@ def __init__(
7878
self._ws_kwargs = kwargs
7979
self.max_queue_size = max_queue_size
8080

81+
async def _propagate_error(self, error_msg: dict):
82+
"""Put error message on the main queue. Subclasses can override to propagate elsewhere."""
83+
await self._queue.put(error_msg)
84+
8185
def json_dumps(self, msg) -> str:
8286
if orjson:
8387
return orjson.dumps(msg).decode("utf-8")
@@ -216,7 +220,7 @@ async def _read_loop(self):
216220
# _no_message_received_reconnect
217221
except asyncio.CancelledError as e:
218222
self._log.debug(f"_read_loop cancelled error {e}")
219-
await self._queue.put({
223+
await self._propagate_error({
220224
"e": "error",
221225
"type": f"{e.__class__.__name__}",
222226
"m": f"{e}",
@@ -231,7 +235,7 @@ async def _read_loop(self):
231235
) as e:
232236
# reports errors and continue loop
233237
self._log.error(f"{e.__class__.__name__} ({e})")
234-
await self._queue.put({
238+
await self._propagate_error({
235239
"e": "error",
236240
"type": f"{e.__class__.__name__}",
237241
"m": f"{e}",
@@ -243,7 +247,7 @@ async def _read_loop(self):
243247
) as e:
244248
# reports errors and break the loop
245249
self._log.error(f"Unknown exception: {e.__class__.__name__} ({e})")
246-
await self._queue.put({
250+
await self._propagate_error({
247251
"e": "error",
248252
"type": e.__class__.__name__,
249253
"m": f"{e}",

binance/ws/websocket_api.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,26 @@
99

1010

1111
class WebsocketAPI(ReconnectingWebsocket):
12-
def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_proxy: Optional[str] = None):
12+
def __init__(
13+
self,
14+
url: str,
15+
tld: str = "com",
16+
testnet: bool = False,
17+
https_proxy: Optional[str] = None,
18+
):
1319
self._tld = tld
1420
self._testnet = testnet
1521
self._responses: Dict[str, asyncio.Future] = {}
1622
self._connection_lock: Optional[asyncio.Lock] = None
1723
# Subscription queues for routing user data stream events
1824
self._subscription_queues: Dict[str, asyncio.Queue] = {}
19-
super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy)
25+
super().__init__(
26+
url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy
27+
)
2028

21-
def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None:
29+
def register_subscription_queue(
30+
self, subscription_id: str, queue: asyncio.Queue
31+
) -> None:
2232
"""Register a queue to receive events for a specific subscription."""
2333
self._subscription_queues[subscription_id] = queue
2434

@@ -33,6 +43,15 @@ def connection_lock(self) -> asyncio.Lock:
3343
self._connection_lock = asyncio.Lock()
3444
return self._connection_lock
3545

46+
async def _propagate_error(self, error_msg: dict):
47+
"""Propagate error to main queue and all subscription queues."""
48+
await super()._propagate_error(error_msg)
49+
for queue in self._subscription_queues.values():
50+
try:
51+
queue.put_nowait(error_msg)
52+
except asyncio.QueueFull:
53+
self._log.error("Subscription queue full, dropping error message")
54+
3655
def _handle_message(self, msg):
3756
"""Override message handling to support request-response"""
3857
parsed_msg = super()._handle_message(msg)
@@ -51,9 +70,13 @@ def _handle_message(self, msg):
5170
try:
5271
queue.put_nowait(event)
5372
except asyncio.QueueFull:
54-
self._log.error(f"Subscription queue full for {subscription_id}, dropping event")
73+
self._log.error(
74+
f"Subscription queue full for {subscription_id}, dropping event"
75+
)
5576
except Exception as e:
56-
self._log.error(f"Error putting event in subscription queue for {subscription_id}: {e}")
77+
self._log.error(
78+
f"Error putting event in subscription queue for {subscription_id}: {e}"
79+
)
5780
return None # Don't put in main queue
5881
else:
5982
# No registered queue, return event for main queue (backward compat)
@@ -65,7 +88,9 @@ def _handle_message(self, msg):
6588
if "status" in parsed_msg:
6689
if parsed_msg["status"] != 200:
6790
exception = BinanceAPIException(
68-
parsed_msg, parsed_msg["status"], self.json_dumps(parsed_msg["error"])
91+
parsed_msg,
92+
parsed_msg["status"],
93+
self.json_dumps(parsed_msg["error"]),
6994
)
7095
if req_id is not None and req_id in self._responses:
7196
if exception is not None:

tests/test_error_propagation.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""
2+
Tests for issue #1678: Connection errors must be propagated to subscription queues.
3+
4+
Verifies that when _read_loop() catches a connection error, the error dict
5+
is delivered to all registered subscription queues (not just the main queue).
6+
"""
7+
8+
import sys
9+
import asyncio
10+
import pytest
11+
from unittest.mock import AsyncMock, PropertyMock
12+
13+
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
14+
import websockets.protocol as ws_protocol
15+
16+
from binance.ws.reconnecting_websocket import ReconnectingWebsocket
17+
from binance.ws.websocket_api import WebsocketAPI
18+
from binance.ws.constants import WSListenerState
19+
20+
21+
# -- Unit tests for _propagate_error --
22+
23+
24+
@pytest.mark.asyncio
25+
async def test_base_propagate_error_puts_on_main_queue():
26+
"""ReconnectingWebsocket._propagate_error should put on self._queue."""
27+
ws = ReconnectingWebsocket(url="wss://test.url")
28+
error_msg = {"e": "error", "type": "TestError", "m": "test"}
29+
30+
await ws._propagate_error(error_msg)
31+
32+
assert ws._queue.qsize() == 1
33+
assert await ws._queue.get() == error_msg
34+
35+
36+
@pytest.mark.asyncio
37+
async def test_websocket_api_propagate_error_puts_on_main_and_subscription_queues():
38+
"""WebsocketAPI._propagate_error should put on main queue AND all subscription queues."""
39+
ws = WebsocketAPI(url="wss://test.url")
40+
sub_queue_1 = asyncio.Queue()
41+
sub_queue_2 = asyncio.Queue()
42+
ws.register_subscription_queue("sub1", sub_queue_1)
43+
ws.register_subscription_queue("sub2", sub_queue_2)
44+
45+
error_msg = {"e": "error", "type": "TestError", "m": "test"}
46+
await ws._propagate_error(error_msg)
47+
48+
# Main queue gets the error
49+
assert ws._queue.qsize() == 1
50+
assert await ws._queue.get() == error_msg
51+
52+
# Both subscription queues get the error
53+
assert sub_queue_1.qsize() == 1
54+
assert await sub_queue_1.get() == error_msg
55+
assert sub_queue_2.qsize() == 1
56+
assert await sub_queue_2.get() == error_msg
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_propagate_error_handles_full_subscription_queue():
61+
"""Should not raise when a subscription queue is full."""
62+
ws = WebsocketAPI(url="wss://test.url")
63+
full_queue = asyncio.Queue(maxsize=1)
64+
full_queue.put_nowait({"existing": "msg"}) # Fill it up
65+
ws.register_subscription_queue("sub_full", full_queue)
66+
67+
error_msg = {"e": "error", "type": "TestError", "m": "test"}
68+
await ws._propagate_error(error_msg)
69+
70+
# Main queue still gets it
71+
assert ws._queue.qsize() == 1
72+
# Full subscription queue still has only the old message (error was dropped)
73+
assert full_queue.qsize() == 1
74+
assert (await full_queue.get())["existing"] == "msg"
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_propagate_error_with_no_subscriptions():
79+
"""Should work fine when no subscription queues are registered."""
80+
ws = WebsocketAPI(url="wss://test.url")
81+
82+
error_msg = {"e": "error", "type": "TestError", "m": "test"}
83+
await ws._propagate_error(error_msg)
84+
85+
assert ws._queue.qsize() == 1
86+
assert await ws._queue.get() == error_msg
87+
88+
89+
# -- Integration tests: _read_loop propagates errors to subscription queues --
90+
91+
92+
def _make_ws_api_with_mock(recv_side_effect):
93+
"""Helper: create a WebsocketAPI with a mocked websocket."""
94+
ws = WebsocketAPI(url="wss://test.url")
95+
mock_ws = AsyncMock()
96+
type(mock_ws).state = PropertyMock(return_value=ws_protocol.State.OPEN)
97+
mock_ws.recv = recv_side_effect
98+
mock_ws.close = AsyncMock()
99+
ws.ws = mock_ws
100+
ws.ws_state = WSListenerState.STREAMING
101+
return ws
102+
103+
104+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
105+
@pytest.mark.asyncio
106+
async def test_read_loop_connection_closed_error_reaches_subscription_queue():
107+
"""ConnectionClosedError in _read_loop should be delivered to subscription queues."""
108+
call_count = 0
109+
110+
async def recv_side_effect():
111+
nonlocal call_count
112+
call_count += 1
113+
if call_count == 1:
114+
raise ConnectionClosedError(None, None)
115+
raise asyncio.CancelledError()
116+
117+
ws = _make_ws_api_with_mock(recv_side_effect)
118+
sub_queue = asyncio.Queue()
119+
ws.register_subscription_queue("user_sub", sub_queue)
120+
121+
try:
122+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
123+
except (asyncio.TimeoutError, asyncio.CancelledError):
124+
pass
125+
126+
assert sub_queue.qsize() >= 1, "Subscription queue should have received the error"
127+
msg = await sub_queue.get()
128+
assert msg["e"] == "error"
129+
assert msg["type"] == "ConnectionClosedError"
130+
131+
132+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
133+
@pytest.mark.asyncio
134+
async def test_read_loop_connection_closed_ok_reaches_subscription_queue():
135+
"""ConnectionClosedOK in _read_loop should be delivered to subscription queues."""
136+
call_count = 0
137+
138+
async def recv_side_effect():
139+
nonlocal call_count
140+
call_count += 1
141+
if call_count == 1:
142+
raise ConnectionClosedOK(None, None)
143+
raise asyncio.CancelledError()
144+
145+
ws = _make_ws_api_with_mock(recv_side_effect)
146+
sub_queue = asyncio.Queue()
147+
ws.register_subscription_queue("user_sub", sub_queue)
148+
149+
try:
150+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
151+
except (asyncio.TimeoutError, asyncio.CancelledError):
152+
pass
153+
154+
assert sub_queue.qsize() >= 1
155+
msg = await sub_queue.get()
156+
assert msg["e"] == "error"
157+
assert msg["type"] == "ConnectionClosedOK"
158+
159+
160+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
161+
@pytest.mark.asyncio
162+
async def test_read_loop_cancelled_error_reaches_subscription_queue():
163+
"""CancelledError in _read_loop should be delivered to subscription queues."""
164+
165+
async def recv_side_effect():
166+
raise asyncio.CancelledError()
167+
168+
ws = _make_ws_api_with_mock(recv_side_effect)
169+
sub_queue = asyncio.Queue()
170+
ws.register_subscription_queue("user_sub", sub_queue)
171+
172+
try:
173+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
174+
except (asyncio.TimeoutError, asyncio.CancelledError):
175+
pass
176+
177+
assert sub_queue.qsize() >= 1
178+
msg = await sub_queue.get()
179+
assert msg["e"] == "error"
180+
assert msg["type"] == "CancelledError"
181+
182+
183+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
184+
@pytest.mark.asyncio
185+
async def test_read_loop_fatal_error_reaches_subscription_queue():
186+
"""Generic exceptions in _read_loop should be delivered to subscription queues."""
187+
188+
async def recv_side_effect():
189+
raise RuntimeError("something broke")
190+
191+
ws = _make_ws_api_with_mock(recv_side_effect)
192+
sub_queue = asyncio.Queue()
193+
ws.register_subscription_queue("user_sub", sub_queue)
194+
195+
try:
196+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
197+
except (asyncio.TimeoutError, asyncio.CancelledError):
198+
pass
199+
200+
assert sub_queue.qsize() >= 1
201+
msg = await sub_queue.get()
202+
assert msg["e"] == "error"
203+
assert msg["type"] == "RuntimeError"
204+
205+
206+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
207+
@pytest.mark.asyncio
208+
async def test_read_loop_error_reaches_multiple_subscription_queues():
209+
"""Errors should be delivered to ALL registered subscription queues."""
210+
211+
async def recv_side_effect():
212+
raise ConnectionClosedError(None, None)
213+
214+
ws = _make_ws_api_with_mock(recv_side_effect)
215+
queues = [asyncio.Queue() for _ in range(3)]
216+
for i, q in enumerate(queues):
217+
ws.register_subscription_queue(f"sub_{i}", q)
218+
219+
# Set EXITING after error to stop loop
220+
original_propagate = ws._propagate_error
221+
222+
async def propagate_and_exit(error_msg):
223+
await original_propagate(error_msg)
224+
ws.ws_state = WSListenerState.EXITING
225+
226+
ws._propagate_error = propagate_and_exit
227+
228+
try:
229+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
230+
except (asyncio.TimeoutError, asyncio.CancelledError):
231+
pass
232+
233+
for i, q in enumerate(queues):
234+
assert q.qsize() >= 1, f"Subscription queue {i} should have received the error"
235+
msg = await q.get()
236+
assert msg["e"] == "error"
237+
assert msg["type"] == "ConnectionClosedError"
238+
239+
240+
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
241+
@pytest.mark.asyncio
242+
async def test_normal_messages_not_duplicated_to_main_queue():
243+
"""Normal subscription messages should go to subscription queue only, not main queue."""
244+
call_count = 0
245+
246+
async def recv_side_effect():
247+
nonlocal call_count
248+
call_count += 1
249+
if call_count == 1:
250+
return '{"subscriptionId": "user_sub", "event": {"e": "executionReport", "s": "BTCUSDT"}}'
251+
raise asyncio.CancelledError()
252+
253+
ws = _make_ws_api_with_mock(recv_side_effect)
254+
sub_queue = asyncio.Queue()
255+
ws.register_subscription_queue("user_sub", sub_queue)
256+
257+
try:
258+
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
259+
except (asyncio.TimeoutError, asyncio.CancelledError):
260+
pass
261+
262+
# Normal message should be in subscription queue
263+
assert sub_queue.qsize() >= 1
264+
msg = await sub_queue.get()
265+
assert msg["e"] == "executionReport"
266+
267+
# Main queue should only have the CancelledError, not the normal message
268+
while not ws._queue.empty():
269+
main_msg = await ws._queue.get()
270+
assert main_msg["e"] == "error", "Main queue should only have error messages"

0 commit comments

Comments
 (0)