diff --git a/src/haclient/core/events.py b/src/haclient/core/events.py index b31ec8d..2ff3dc6 100644 --- a/src/haclient/core/events.py +++ b/src/haclient/core/events.py @@ -18,6 +18,8 @@ from __future__ import annotations +import asyncio +import contextlib import logging from collections import defaultdict, deque from collections.abc import Awaitable @@ -35,6 +37,19 @@ class EventBus: ---------- ws : WebSocketPort The transport used to subscribe to Home Assistant events. + + Notes + ----- + Subscriptions can fail at the transport layer. The bus exposes two + APIs to handle this: + + * `subscribe` / `unsubscribe` — fire-and-forget; failures are logged + and recorded on the bus so callers can inspect them via + `subscription_failure` and `pending_subscription`. + * `subscribe_async` / `unsubscribe_async` — awaitable; transport + errors are raised to the caller. Prefer these whenever the caller + needs confirmation that Home Assistant actually accepted the + subscription. """ def __init__(self, ws: WebSocketPort) -> None: @@ -43,13 +58,22 @@ def __init__(self, ws: WebSocketPort) -> None: self._subscription_ids: dict[str, int] = {} self._buffers: dict[str, deque[dict[str, Any]]] = {} self._started = False + # Per-event-type background subscription task scheduled by the + # fire-and-forget `subscribe()` path. Replaced (and the previous + # task discarded) on each new attempt. + self._pending_subscriptions: dict[str, asyncio.Task[None]] = {} + # Last subscription failure observed for a given event type via + # the fire-and-forget path. Cleared on a successful retry. + self._subscription_failures: dict[str, BaseException] = {} def subscribe(self, event_type: str, handler: EventHandler) -> EventHandler: """Register *handler* for the given *event_type*. Subscriptions registered before `start` are batched; those added afterwards trigger an immediate WebSocket subscribe if it is the - first handler for the type. + first handler for the type. The scheduled task is tracked so + callers can await it (`pending_subscription`) or inspect its + outcome (`subscription_failure`). Parameters ---------- @@ -62,21 +86,75 @@ def subscribe(self, event_type: str, handler: EventHandler) -> EventHandler: ------- callable The same *handler*, for use as a decorator. + + Notes + ----- + This method does not raise transport errors. Use + `subscribe_async` when the caller needs to know whether Home + Assistant accepted the subscription. """ first_for_type = event_type not in self._handlers self._handlers[event_type].append(handler) if self._started and first_for_type: # Subscribe lazily; the WS adapter handles re-subscription on reconnect. - import asyncio + task = asyncio.ensure_future(self._ensure_subscription(event_type)) + self._pending_subscriptions[event_type] = task + + def _done(t: asyncio.Task[None], et: str = event_type) -> None: + self._on_subscription_task_done(et, t) + + task.add_done_callback(_done) + return handler + + async def subscribe_async(self, event_type: str, handler: EventHandler) -> EventHandler: + """Register *handler* and await the underlying WebSocket subscribe. + + Like `subscribe`, but transport failures propagate to the caller + and the handler is rolled back if the first subscribe for an + event type fails — so callers can rely on the returned handler + being live. + + Parameters + ---------- + event_type : str + The Home Assistant event type. + handler : callable + Sync or async callable receiving the event dict. + + Returns + ------- + callable + The registered handler. - asyncio.ensure_future(self._ensure_subscription(event_type)) + Raises + ------ + Exception + Any exception raised by the underlying `WebSocketPort` when + the first handler for *event_type* triggers a subscribe. + """ + first_for_type = event_type not in self._handlers + self._handlers[event_type].append(handler) + if not (self._started and first_for_type): + return handler + try: + await self._subscribe_now(event_type) + except BaseException: + # Roll back the handler so the caller's view is consistent + # with the transport state. + handlers = self._handlers.get(event_type) + if handlers is not None: + with contextlib.suppress(ValueError): # pragma: no cover - defensive + handlers.remove(handler) + if not handlers: + self._handlers.pop(event_type, None) + raise return handler def unsubscribe(self, event_type: str, handler: EventHandler) -> None: """Remove a previously registered handler. If the last handler for *event_type* is removed the WebSocket - subscription is also cancelled. + subscription is also cancelled in the background. Parameters ---------- @@ -86,20 +164,105 @@ def unsubscribe(self, event_type: str, handler: EventHandler) -> None: The exact handler previously passed to `subscribe`. Removing an unknown handler is a no-op. """ + sub_id = self._drop_handler(event_type, handler) + if sub_id is not None and self._ws.connected: + asyncio.ensure_future(self._safe_unsubscribe(sub_id)) + + async def unsubscribe_async(self, event_type: str, handler: EventHandler) -> None: + """Remove a handler and await any resulting WebSocket unsubscribe. + + Unlike `unsubscribe`, transport errors raised while telling Home + Assistant to stop sending events propagate to the caller. + + Parameters + ---------- + event_type : str + The Home Assistant event type to unsubscribe from. + handler : callable + The exact handler previously passed to `subscribe` or + `subscribe_async`. Removing an unknown handler is a no-op. + + Raises + ------ + Exception + Any exception raised by `WebSocketPort.unsubscribe`. + """ + sub_id = self._drop_handler(event_type, handler) + if sub_id is not None and self._ws.connected: + await self._ws.unsubscribe(sub_id) + + def _drop_handler(self, event_type: str, handler: EventHandler) -> int | None: + """Remove *handler* and return the WS subscription id to release. + + Returns ``None`` if the handler was unknown or other handlers + remain for *event_type*. + """ handlers = self._handlers.get(event_type) if not handlers: - return + return None try: handlers.remove(handler) except ValueError: - return - if not handlers: - self._handlers.pop(event_type, None) - sub_id = self._subscription_ids.pop(event_type, None) - if sub_id is not None and self._ws.connected: - import asyncio + return None + if handlers: + return None + self._handlers.pop(event_type, None) + return self._subscription_ids.pop(event_type, None) - asyncio.ensure_future(self._safe_unsubscribe(sub_id)) + def subscription_failure(self, event_type: str) -> BaseException | None: + """Return the last fire-and-forget subscribe failure, if any. + + Parameters + ---------- + event_type : str + The event type to inspect. + + Returns + ------- + BaseException or None + The exception raised by the most recent fire-and-forget + subscribe attempt, or ``None`` if the current subscription + is healthy (or no attempt has been made). + """ + return self._subscription_failures.get(event_type) + + def pending_subscription(self, event_type: str) -> asyncio.Task[None] | None: + """Return the in-flight subscribe task for *event_type*, if any. + + Awaiting the returned task lets callers convert a fire-and-forget + `subscribe` into a confirmed registration without changing the + original call site. + + Parameters + ---------- + event_type : str + The event type whose pending subscribe task to return. + + Returns + ------- + asyncio.Task or None + The scheduled task, or ``None`` if no subscribe is in flight + for *event_type*. + """ + task = self._pending_subscriptions.get(event_type) + if task is None or task.done(): + return None + return task + + def _on_subscription_task_done(self, event_type: str, task: asyncio.Task[None]) -> None: + """Record the outcome of a fire-and-forget subscribe task.""" + # Only forget the task if it is still the registered one — a + # later attempt may have replaced it. + if self._pending_subscriptions.get(event_type) is task: + self._pending_subscriptions.pop(event_type, None) + if task.cancelled(): + return + exc = task.exception() + if exc is None: + # Success: clear any stale failure. + self._subscription_failures.pop(event_type, None) + else: + self._subscription_failures[event_type] = exc async def _safe_unsubscribe(self, sub_id: int) -> None: """Unsubscribe, swallowing transport errors.""" @@ -112,22 +275,51 @@ async def start(self) -> None: """Subscribe to every registered event type and arm reconnect. Safe to call multiple times. + + Notes + ----- + Transport failures during the initial batch are recorded on the + bus (see `subscription_failure`) but not raised, so a single + flaky event type does not abort startup. Use `subscribe_async` + afterwards if you need confirmation that a specific subscription + is live. """ if self._started: return for event_type in list(self._handlers.keys()): - await self._ensure_subscription(event_type) + try: + await self._ensure_subscription(event_type) + except Exception: # noqa: BLE001 - logged & recorded by _ensure_subscription + continue self._started = True async def _ensure_subscription(self, event_type: str) -> None: - """Subscribe on the WS if not already subscribed.""" - if event_type in self._subscription_ids: - return + """Subscribe on the WS, logging transport errors and recording them. + + Used by `start` and by the fire-and-forget `subscribe` path. The + exception is re-raised so the post-start path's task done + callback can store it on `_subscription_failures`; the `start` + path catches it again to preserve the historic "start never + raises on subscribe failure" behaviour, but still records the + failure for observability. + """ try: - sub_id = await self._ws.subscribe_events(self._make_dispatcher(event_type), event_type) - except Exception: + await self._subscribe_now(event_type) + except Exception as exc: _LOGGER.exception("EventBus failed to subscribe to %s", event_type) + self._subscription_failures[event_type] = exc + raise + + async def _subscribe_now(self, event_type: str) -> None: + """Subscribe on the WS, propagating transport errors. + + Used by `subscribe_async` and (indirectly) by + `_ensure_subscription`. Idempotent — returns immediately if a + subscription id is already recorded for *event_type*. + """ + if event_type in self._subscription_ids: return + sub_id = await self._ws.subscribe_events(self._make_dispatcher(event_type), event_type) self._subscription_ids[event_type] = sub_id def _make_dispatcher(self, event_type: str) -> EventHandler: diff --git a/tests/test_events.py b/tests/test_events.py index 5daec86..95ccffc 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -183,14 +183,206 @@ async def test_double_start_is_noop() -> None: assert len(ws.subscriptions) == 1 -async def test_subscribe_failure_logged_not_propagated() -> None: +async def test_subscribe_failure_during_start_is_recorded() -> None: + """Pre-start subscriptions failing during `start` must be observable. + + The historical contract is that `start` does not raise so a single + flaky event type cannot abort the whole client. The new contract is + that the failure is recorded on the bus so callers can inspect it. + """ ws = _FakeWS() ws.subscribe_failure = RuntimeError("boom") bus = EventBus(ws) # type: ignore[arg-type] bus.subscribe("ev", lambda e: None) - await bus.start() + await bus.start() # must not raise # No id recorded because the WS subscribe raised. assert ws.subscriptions == {} + failure = bus.subscription_failure("ev") + assert isinstance(failure, RuntimeError) + assert str(failure) == "boom" + + +async def test_subscribe_post_start_failure_is_observable() -> None: + """A post-start `subscribe` that fails must surface via the API.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.start() + ws.subscribe_failure = RuntimeError("nope") + bus.subscribe("ev", lambda e: None) + pending = bus.pending_subscription("ev") + assert pending is not None + # The task completes with the transport error and the failure is + # exposed via `subscription_failure`. + with pytest.raises(RuntimeError, match="nope"): + await pending + failure = bus.subscription_failure("ev") + assert isinstance(failure, RuntimeError) + assert ws.subscriptions == {} + + +async def test_subscribe_post_start_recovers_on_retry() -> None: + """A successful subsequent attempt clears the recorded failure.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.start() + ws.subscribe_failure = RuntimeError("transient") + + def handler(_e: dict[str, Any]) -> None: + return None + + bus.subscribe("ev", handler) + pending = bus.pending_subscription("ev") + assert pending is not None + with pytest.raises(RuntimeError): + await pending + assert bus.subscription_failure("ev") is not None + + # Drop and retry with the transport healthy. + bus.unsubscribe("ev", handler) + ws.subscribe_failure = None + bus.subscribe("ev", handler) + retry = bus.pending_subscription("ev") + assert retry is not None + await retry + assert bus.subscription_failure("ev") is None + assert len(ws.subscriptions) == 1 + + +async def test_subscribe_async_propagates_failure_and_rolls_back() -> None: + """`subscribe_async` raises and leaves no dangling handler.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.start() + ws.subscribe_failure = RuntimeError("denied") + + def handler(_e: dict[str, Any]) -> None: + return None + + with pytest.raises(RuntimeError, match="denied"): + await bus.subscribe_async("ev", handler) + + # Handler must NOT remain registered after a failed first subscribe, + # otherwise the caller's view diverges from the transport state. + assert ws.subscriptions == {} + ws.subscribe_failure = None + # A subsequent subscribe is treated as a fresh "first handler" and + # actually performs the WS call. + await bus.subscribe_async("ev", handler) + assert len(ws.subscriptions) == 1 + + +async def test_subscribe_async_before_start_is_batched() -> None: + """`subscribe_async` registered before `start` does not hit the WS.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.subscribe_async("ev", lambda _e: None) + assert ws.subscriptions == {} + await bus.start() + assert len(ws.subscriptions) == 1 + + +async def test_subscribe_async_additional_handler_skips_ws_call() -> None: + """Adding a second handler for an event type re-uses the subscription.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.start() + await bus.subscribe_async("ev", lambda _e: None) + assert len(ws.subscriptions) == 1 + # A second handler must succeed without raising even if a fresh + # subscribe would now fail; it should not call the WS at all. + ws.subscribe_failure = RuntimeError("would fail if called") + await bus.subscribe_async("ev", lambda _e: None) + assert len(ws.subscriptions) == 1 + + +async def test_unsubscribe_async_propagates_transport_errors() -> None: + """`unsubscribe_async` raises if the WS rejects the call.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + + def handler(_e: dict[str, Any]) -> None: + return None + + bus.subscribe("ev", handler) + await bus.start() + + original = ws.unsubscribe + + async def failing_unsubscribe(sub_id: int) -> None: + raise RuntimeError("nope") + + ws.unsubscribe = failing_unsubscribe # type: ignore[method-assign] + try: + with pytest.raises(RuntimeError, match="nope"): + await bus.unsubscribe_async("ev", handler) + finally: + ws.unsubscribe = original # type: ignore[method-assign] + + +async def test_unsubscribe_async_unknown_handler_is_noop() -> None: + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + await bus.unsubscribe_async("ev", lambda _e: None) # never registered + bus.subscribe("ev", lambda _e: None) + await bus.unsubscribe_async("ev", lambda _e: None) # different instance + + +async def test_pending_subscription_none_when_idle() -> None: + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + assert bus.pending_subscription("ev") is None + bus.subscribe("ev", lambda _e: None) + await bus.start() + # Subscribe was done synchronously inside start(), no pending task. + assert bus.pending_subscription("ev") is None + + +async def test_unsubscribe_async_does_not_release_ws_when_handlers_remain() -> None: + """Removing one of several handlers must not tear down the WS subscription.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + + def h1(_e: dict[str, Any]) -> None: + return None + + def h2(_e: dict[str, Any]) -> None: + return None + + bus.subscribe("ev", h1) + bus.subscribe("ev", h2) + await bus.start() + await bus.unsubscribe_async("ev", h1) + assert ws.unsubscribed == [] # WS subscription still active for h2 + assert len(ws.subscriptions) == 1 + + +async def test_cancelled_subscription_task_clears_pending() -> None: + """A cancelled post-start subscribe task is forgotten cleanly.""" + ws = _FakeWS() + bus = EventBus(ws) # type: ignore[arg-type] + + # Slow down the WS so we can cancel the in-flight task. + started = asyncio.Event() + release = asyncio.Event() + + async def slow_subscribe(handler: Any, event_type: str | None = None) -> int: + started.set() + await release.wait() + return 1 + + ws.subscribe_events = slow_subscribe # type: ignore[method-assign] + await bus.start() + bus.subscribe("ev", lambda _e: None) + task = bus.pending_subscription("ev") + assert task is not None + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + # No failure recorded for a cancellation; the slot is cleared. + assert bus.subscription_failure("ev") is None + assert bus.pending_subscription("ev") is None + release.set() async def test_install_reconnect_hook_invokes_callback() -> None: