Skip to content

Commit b338722

Browse files
committed
Address a potential issue in the PybricksHubUSB.write_gatt_char method. Previously, there was a concern that this part of the code could get stuck if the USB hub disconnected or didn't send a response.
Here's how I've addressed it: - I've added a 5-second timeout for waiting for a response from the hub. - I'm now also monitoring for a hub disconnection while waiting for the response. If the hub disconnects, a `RuntimeError` will occur. If the operation times out, an `asyncio.TimeoutError` will occur. I've also included some checks in `tests/connections/test_pybricks.py` to ensure this new behavior works as expected in both disconnection and timeout situations.
1 parent e0c8dad commit b338722

2 files changed

Lines changed: 104 additions & 5 deletions

File tree

pybricksdev/connections/pybricks.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -837,10 +837,17 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
837837
raise ValueError("Response is required for USB")
838838

839839
self._ep_out.write(bytes([PybricksUsbOutEpMessageType.COMMAND]) + data)
840-
# FIXME: This needs to race with hub disconnect, and could also use a
841-
# timeout, otherwise it blocks forever. Pyusb doesn't currently seem to
842-
# have any disconnect callback.
843-
reply = await self._response_queue.get()
840+
841+
try:
842+
reply = await asyncio.wait_for(
843+
self.race_disconnect(self._response_queue.get()),
844+
timeout=5.0, # 5-second timeout
845+
)
846+
except asyncio.TimeoutError:
847+
# Handle timeout specifically if needed, or let race_disconnect handle it
848+
# For now, let's make it explicit
849+
raise asyncio.TimeoutError("Timeout waiting for USB response")
850+
# race_disconnect will raise RuntimeError if disconnected
844851

845852
# REVISIT: could look up status error code and convert to string,
846853
# although BLE doesn't do that either.

tests/connections/test_pybricks.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,20 @@
44
import contextlib
55
import os
66
import tempfile
7-
from unittest.mock import AsyncMock, PropertyMock, patch
7+
from unittest.mock import AsyncMock, PropertyMock, patch, MagicMock
88

99
import pytest
1010
from reactivex.subject import Subject
1111

12+
13+
from pybricksdev.ble.pybricks import PYBRICKS_COMMAND_EVENT_UUID
14+
from pybricksdev.usb.pybricks import PybricksUsbOutEpMessageType
1215
from pybricksdev.connections.pybricks import (
1316
ConnectionState,
1417
HubCapabilityFlag,
1518
HubKind,
1619
PybricksHubBLE,
20+
PybricksHubUSB,
1721
StatusFlag,
1822
)
1923

@@ -180,3 +184,91 @@ async def test_run_modern_protocol(self):
180184
# Verify the expected calls were made
181185
hub.download_user_program.assert_called_once()
182186
hub.start_user_program.assert_called_once()
187+
188+
189+
class TestPybricksHubUSB:
190+
"""Tests for the PybricksHubUSB class functionality."""
191+
192+
@pytest.mark.asyncio
193+
async def test_pybricks_hub_usb_write_gatt_char_disconnect(self):
194+
"""Test write_gatt_char when a disconnect event occurs."""
195+
hub = PybricksHubUSB(MagicMock())
196+
197+
hub._ep_out = MagicMock()
198+
# Simulate _response_queue.get() blocking indefinitely
199+
hub._response_queue = AsyncMock()
200+
hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait)
201+
202+
mock_observable = MagicMock(
203+
spec=Subject
204+
) # Using Subject as a base for mock spec
205+
disconnect_callback_handler = None
206+
207+
def mock_subscribe_side_effect(on_next_callback, *args, **kwargs):
208+
nonlocal disconnect_callback_handler
209+
disconnect_callback_handler = on_next_callback
210+
mock_subscription = MagicMock()
211+
mock_subscription.dispose = MagicMock()
212+
return mock_subscription
213+
214+
mock_observable.subscribe = MagicMock(side_effect=mock_subscribe_side_effect)
215+
type(hub.connection_state_observable).value = PropertyMock(
216+
return_value=ConnectionState.CONNECTED
217+
)
218+
hub.connection_state_observable = mock_observable
219+
220+
async def trigger_disconnect_event():
221+
await asyncio.sleep(0.05)
222+
assert (
223+
disconnect_callback_handler is not None
224+
), "Subscribe was not called by race_disconnect"
225+
disconnect_callback_handler(ConnectionState.DISCONNECTED)
226+
227+
with pytest.raises(RuntimeError, match="disconnected during operation"):
228+
await asyncio.gather(
229+
hub.write_gatt_char(PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True),
230+
trigger_disconnect_event(),
231+
)
232+
233+
hub._ep_out.write.assert_called_once_with(
234+
bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data"
235+
)
236+
237+
@pytest.mark.asyncio
238+
async def test_pybricks_hub_usb_write_gatt_char_timeout(self):
239+
"""Test write_gatt_char when a timeout occurs."""
240+
hub = PybricksHubUSB(MagicMock())
241+
242+
hub._ep_out = MagicMock()
243+
hub._response_queue = AsyncMock()
244+
# Make _response_queue.get() block indefinitely
245+
hub._response_queue.get = AsyncMock(side_effect=asyncio.Event().wait)
246+
247+
mock_observable = MagicMock(spec=Subject)
248+
249+
def mock_subscribe_side_effect(on_next_callback, *args, **kwargs):
250+
mock_subscription = MagicMock()
251+
mock_subscription.dispose = MagicMock()
252+
return mock_subscription
253+
254+
mock_observable.subscribe = MagicMock(side_effect=mock_subscribe_side_effect)
255+
type(hub.connection_state_observable).value = PropertyMock(
256+
return_value=ConnectionState.CONNECTED
257+
)
258+
hub.connection_state_observable = mock_observable
259+
260+
# The method has a hardcoded timeout of 5.0s.
261+
# We can patch asyncio.wait_for to speed up the test.
262+
with patch(
263+
"asyncio.wait_for", side_effect=asyncio.TimeoutError("Test-induced timeout")
264+
):
265+
with pytest.raises(
266+
asyncio.TimeoutError, match="Timeout waiting for USB response"
267+
):
268+
await hub.write_gatt_char(
269+
PYBRICKS_COMMAND_EVENT_UUID, b"test_data", True
270+
)
271+
272+
hub._ep_out.write.assert_called_once_with(
273+
bytes([PybricksUsbOutEpMessageType.COMMAND]) + b"test_data"
274+
)

0 commit comments

Comments
 (0)