Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions binance/ws/reconnecting_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
pass

try:
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
except ImportError:
from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore

Expand Down Expand Up @@ -78,6 +78,10 @@ def __init__(
self._ws_kwargs = kwargs
self.max_queue_size = max_queue_size

async def _propagate_error(self, error_msg: dict):
"""Put error message on the main queue. Subclasses can override to propagate elsewhere."""
await self._queue.put(error_msg)

def json_dumps(self, msg) -> str:
if orjson:
return orjson.dumps(msg).decode("utf-8")
Expand Down Expand Up @@ -216,7 +220,7 @@ async def _read_loop(self):
# _no_message_received_reconnect
except asyncio.CancelledError as e:
self._log.debug(f"_read_loop cancelled error {e}")
await self._queue.put({
await self._propagate_error({
"e": "error",
"type": f"{e.__class__.__name__}",
"m": f"{e}",
Expand All @@ -231,7 +235,7 @@ async def _read_loop(self):
) as e:
# reports errors and continue loop
self._log.error(f"{e.__class__.__name__} ({e})")
await self._queue.put({
await self._propagate_error({
"e": "error",
"type": f"{e.__class__.__name__}",
"m": f"{e}",
Expand All @@ -243,7 +247,7 @@ async def _read_loop(self):
) as e:
# reports errors and break the loop
self._log.error(f"Unknown exception: {e.__class__.__name__} ({e})")
await self._queue.put({
await self._propagate_error({
"e": "error",
"type": e.__class__.__name__,
"m": f"{e}",
Expand Down
37 changes: 31 additions & 6 deletions binance/ws/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,26 @@


class WebsocketAPI(ReconnectingWebsocket):
def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_proxy: Optional[str] = None):
def __init__(
self,
url: str,
tld: str = "com",
testnet: bool = False,
https_proxy: Optional[str] = None,
):
self._tld = tld
self._testnet = testnet
self._responses: Dict[str, asyncio.Future] = {}
self._connection_lock: Optional[asyncio.Lock] = None
# Subscription queues for routing user data stream events
self._subscription_queues: Dict[str, asyncio.Queue] = {}
super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy)
super().__init__(
url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy
)

def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None:
def register_subscription_queue(
self, subscription_id: str, queue: asyncio.Queue
) -> None:
"""Register a queue to receive events for a specific subscription."""
self._subscription_queues[subscription_id] = queue

Expand All @@ -33,6 +43,15 @@ def connection_lock(self) -> asyncio.Lock:
self._connection_lock = asyncio.Lock()
return self._connection_lock

async def _propagate_error(self, error_msg: dict):
"""Propagate error to main queue and all subscription queues."""
await super()._propagate_error(error_msg)
for queue in self._subscription_queues.values():
try:
queue.put_nowait(error_msg)
except asyncio.QueueFull:
self._log.error("Subscription queue full, dropping error message")

def _handle_message(self, msg):
"""Override message handling to support request-response"""
parsed_msg = super()._handle_message(msg)
Expand All @@ -51,9 +70,13 @@ def _handle_message(self, msg):
try:
queue.put_nowait(event)
except asyncio.QueueFull:
self._log.error(f"Subscription queue full for {subscription_id}, dropping event")
self._log.error(
f"Subscription queue full for {subscription_id}, dropping event"
)
except Exception as e:
self._log.error(f"Error putting event in subscription queue for {subscription_id}: {e}")
self._log.error(
f"Error putting event in subscription queue for {subscription_id}: {e}"
)
return None # Don't put in main queue
else:
# No registered queue, return event for main queue (backward compat)
Expand All @@ -65,7 +88,9 @@ def _handle_message(self, msg):
if "status" in parsed_msg:
if parsed_msg["status"] != 200:
exception = BinanceAPIException(
parsed_msg, parsed_msg["status"], self.json_dumps(parsed_msg["error"])
parsed_msg,
parsed_msg["status"],
self.json_dumps(parsed_msg["error"]),
)
if req_id is not None and req_id in self._responses:
if exception is not None:
Expand Down
270 changes: 270 additions & 0 deletions tests/test_error_propagation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
"""
Tests for issue #1678: Connection errors must be propagated to subscription queues.

Verifies that when _read_loop() catches a connection error, the error dict
is delivered to all registered subscription queues (not just the main queue).
"""

import sys
import asyncio
import pytest
from unittest.mock import AsyncMock, PropertyMock

from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
import websockets.protocol as ws_protocol

from binance.ws.reconnecting_websocket import ReconnectingWebsocket
from binance.ws.websocket_api import WebsocketAPI
from binance.ws.constants import WSListenerState


# -- Unit tests for _propagate_error --


@pytest.mark.asyncio
async def test_base_propagate_error_puts_on_main_queue():
"""ReconnectingWebsocket._propagate_error should put on self._queue."""
ws = ReconnectingWebsocket(url="wss://test.url")
error_msg = {"e": "error", "type": "TestError", "m": "test"}

await ws._propagate_error(error_msg)

assert ws._queue.qsize() == 1
assert await ws._queue.get() == error_msg


@pytest.mark.asyncio
async def test_websocket_api_propagate_error_puts_on_main_and_subscription_queues():
"""WebsocketAPI._propagate_error should put on main queue AND all subscription queues."""
ws = WebsocketAPI(url="wss://test.url")
sub_queue_1 = asyncio.Queue()
sub_queue_2 = asyncio.Queue()
ws.register_subscription_queue("sub1", sub_queue_1)
ws.register_subscription_queue("sub2", sub_queue_2)

error_msg = {"e": "error", "type": "TestError", "m": "test"}
await ws._propagate_error(error_msg)

# Main queue gets the error
assert ws._queue.qsize() == 1
assert await ws._queue.get() == error_msg

# Both subscription queues get the error
assert sub_queue_1.qsize() == 1
assert await sub_queue_1.get() == error_msg
assert sub_queue_2.qsize() == 1
assert await sub_queue_2.get() == error_msg


@pytest.mark.asyncio
async def test_propagate_error_handles_full_subscription_queue():
"""Should not raise when a subscription queue is full."""
ws = WebsocketAPI(url="wss://test.url")
full_queue = asyncio.Queue(maxsize=1)
full_queue.put_nowait({"existing": "msg"}) # Fill it up
ws.register_subscription_queue("sub_full", full_queue)

error_msg = {"e": "error", "type": "TestError", "m": "test"}
await ws._propagate_error(error_msg)

# Main queue still gets it
assert ws._queue.qsize() == 1
# Full subscription queue still has only the old message (error was dropped)
assert full_queue.qsize() == 1
assert (await full_queue.get())["existing"] == "msg"


@pytest.mark.asyncio
async def test_propagate_error_with_no_subscriptions():
"""Should work fine when no subscription queues are registered."""
ws = WebsocketAPI(url="wss://test.url")

error_msg = {"e": "error", "type": "TestError", "m": "test"}
await ws._propagate_error(error_msg)

assert ws._queue.qsize() == 1
assert await ws._queue.get() == error_msg


# -- Integration tests: _read_loop propagates errors to subscription queues --


def _make_ws_api_with_mock(recv_side_effect):
"""Helper: create a WebsocketAPI with a mocked websocket."""
ws = WebsocketAPI(url="wss://test.url")
mock_ws = AsyncMock()
type(mock_ws).state = PropertyMock(return_value=ws_protocol.State.OPEN)
mock_ws.recv = recv_side_effect
mock_ws.close = AsyncMock()
ws.ws = mock_ws
ws.ws_state = WSListenerState.STREAMING
return ws


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_read_loop_connection_closed_error_reaches_subscription_queue():
"""ConnectionClosedError in _read_loop should be delivered to subscription queues."""
call_count = 0

async def recv_side_effect():
nonlocal call_count
call_count += 1
if call_count == 1:
raise ConnectionClosedError(None, None)
raise asyncio.CancelledError()

ws = _make_ws_api_with_mock(recv_side_effect)
sub_queue = asyncio.Queue()
ws.register_subscription_queue("user_sub", sub_queue)

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

assert sub_queue.qsize() >= 1, "Subscription queue should have received the error"
msg = await sub_queue.get()
assert msg["e"] == "error"
assert msg["type"] == "ConnectionClosedError"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_read_loop_connection_closed_ok_reaches_subscription_queue():
"""ConnectionClosedOK in _read_loop should be delivered to subscription queues."""
call_count = 0

async def recv_side_effect():
nonlocal call_count
call_count += 1
if call_count == 1:
raise ConnectionClosedOK(None, None)
raise asyncio.CancelledError()

ws = _make_ws_api_with_mock(recv_side_effect)
sub_queue = asyncio.Queue()
ws.register_subscription_queue("user_sub", sub_queue)

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

assert sub_queue.qsize() >= 1
msg = await sub_queue.get()
assert msg["e"] == "error"
assert msg["type"] == "ConnectionClosedOK"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_read_loop_cancelled_error_reaches_subscription_queue():
"""CancelledError in _read_loop should be delivered to subscription queues."""

async def recv_side_effect():
raise asyncio.CancelledError()

ws = _make_ws_api_with_mock(recv_side_effect)
sub_queue = asyncio.Queue()
ws.register_subscription_queue("user_sub", sub_queue)

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

assert sub_queue.qsize() >= 1
msg = await sub_queue.get()
assert msg["e"] == "error"
assert msg["type"] == "CancelledError"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_read_loop_fatal_error_reaches_subscription_queue():
"""Generic exceptions in _read_loop should be delivered to subscription queues."""

async def recv_side_effect():
raise RuntimeError("something broke")

ws = _make_ws_api_with_mock(recv_side_effect)
sub_queue = asyncio.Queue()
ws.register_subscription_queue("user_sub", sub_queue)

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

assert sub_queue.qsize() >= 1
msg = await sub_queue.get()
assert msg["e"] == "error"
assert msg["type"] == "RuntimeError"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_read_loop_error_reaches_multiple_subscription_queues():
"""Errors should be delivered to ALL registered subscription queues."""

async def recv_side_effect():
raise ConnectionClosedError(None, None)

ws = _make_ws_api_with_mock(recv_side_effect)
queues = [asyncio.Queue() for _ in range(3)]
for i, q in enumerate(queues):
ws.register_subscription_queue(f"sub_{i}", q)

# Set EXITING after error to stop loop
original_propagate = ws._propagate_error

async def propagate_and_exit(error_msg):
await original_propagate(error_msg)
ws.ws_state = WSListenerState.EXITING

ws._propagate_error = propagate_and_exit

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

for i, q in enumerate(queues):
assert q.qsize() >= 1, f"Subscription queue {i} should have received the error"
msg = await q.get()
assert msg["e"] == "error"
assert msg["type"] == "ConnectionClosedError"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
@pytest.mark.asyncio
async def test_normal_messages_not_duplicated_to_main_queue():
"""Normal subscription messages should go to subscription queue only, not main queue."""
call_count = 0

async def recv_side_effect():
nonlocal call_count
call_count += 1
if call_count == 1:
return '{"subscriptionId": "user_sub", "event": {"e": "executionReport", "s": "BTCUSDT"}}'
raise asyncio.CancelledError()

ws = _make_ws_api_with_mock(recv_side_effect)
sub_queue = asyncio.Queue()
ws.register_subscription_queue("user_sub", sub_queue)

try:
await asyncio.wait_for(ws._read_loop(), timeout=3.0)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass

# Normal message should be in subscription queue
assert sub_queue.qsize() >= 1
msg = await sub_queue.get()
assert msg["e"] == "executionReport"

# Main queue should only have the CancelledError, not the normal message
while not ws._queue.empty():
main_msg = await ws._queue.get()
assert main_msg["e"] == "error", "Main queue should only have error messages"
Loading