Skip to content

Commit 93d5d85

Browse files
committed
Add callback support
1 parent 0a3edcd commit 93d5d85

2 files changed

Lines changed: 108 additions & 0 deletions

File tree

src/blueapi/client/client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import itertools
12
import logging
23
import time
4+
from collections.abc import Iterable
35
from concurrent.futures import Future
46
from functools import cached_property
57
from itertools import chain
@@ -196,6 +198,8 @@ class BlueapiClient:
196198
_rest: BlueapiRestClient
197199
_events: EventBusClient | None
198200
_instrument_session: str | None = None
201+
_callbacks: dict[int, OnAnyEvent]
202+
_callback_id: itertools.count
199203

200204
def __init__(
201205
self,
@@ -204,6 +208,8 @@ def __init__(
204208
):
205209
self._rest = rest
206210
self._events = events
211+
self._callbacks = {}
212+
self._callback_id = itertools.count()
207213

208214
@cached_property
209215
@start_as_current_span(TRACER)
@@ -258,6 +264,22 @@ def instrument_session(self, session: str):
258264
log.debug("Setting instrument_session to %s", session)
259265
self._instrument_session = session
260266

267+
def with_instrument_session(self, session: str) -> Self:
268+
self.instrument_session = session
269+
return self
270+
271+
def add_callback(self, callback: OnAnyEvent) -> int:
272+
cb_id = next(self._callback_id)
273+
self._callbacks[cb_id] = callback
274+
return cb_id
275+
276+
def remove_callback(self, id: int):
277+
self._callbacks.pop(id)
278+
279+
@property
280+
def callbacks(self) -> Iterable[OnAnyEvent]:
281+
return self._callbacks.values()
282+
261283
@property
262284
@start_as_current_span(TRACER)
263285
def state(self) -> WorkerState:
@@ -355,6 +377,13 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None:
355377
if relates_to_task:
356378
if on_event is not None:
357379
on_event(event)
380+
for cb in self._callbacks.values():
381+
try:
382+
cb(event)
383+
except Exception as e:
384+
log.error(
385+
f"Callback ({cb}) failed for event: {event}", exc_info=e
386+
)
358387
if isinstance(event, WorkerEvent) and (
359388
(event.is_complete()) and (ctx.correlation_id == task_id)
360389
):

tests/unit_tests/client/test_client.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,3 +764,82 @@ def test_plan_invalid_param_mapping(args, kwargs, msg):
764764
with pytest.raises(TypeError, match=msg):
765765
plan(*args, **kwargs)
766766
client.run_task.assert_not_called()
767+
768+
769+
def test_adding_removing_callback(client):
770+
def callback(*a, **kw):
771+
pass
772+
773+
cb_id = client.add_callback(callback)
774+
assert len(client.callbacks) == 1
775+
client.remove_callback(cb_id)
776+
assert len(client.callbacks) == 0
777+
778+
779+
@pytest.mark.parametrize(
780+
"test_event",
781+
[
782+
WorkerEvent(
783+
state=WorkerState.RUNNING,
784+
task_status=TaskStatus(
785+
task_id="foo",
786+
task_complete=False,
787+
task_failed=False,
788+
),
789+
),
790+
ProgressEvent(task_id="foo"),
791+
DataEvent(name="start", doc={}, task_id="0000-1111"),
792+
],
793+
)
794+
def test_client_callbacks(
795+
client_with_events: BlueapiClient,
796+
mock_rest: Mock,
797+
mock_events: MagicMock,
798+
test_event: AnyEvent,
799+
):
800+
callback = Mock()
801+
client_with_events.add_callback(callback)
802+
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
803+
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
804+
805+
ctx = Mock()
806+
ctx.correlation_id = "foo"
807+
808+
def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]):
809+
on_event(test_event, ctx)
810+
on_event(COMPLETE_EVENT, ctx)
811+
812+
mock_events.subscribe_to_all_events = subscribe # type: ignore
813+
814+
client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1"))
815+
816+
assert callback.mock_calls == [call(test_event), call(COMPLETE_EVENT)]
817+
818+
819+
def test_client_callback_failures(
820+
client_with_events: BlueapiClient,
821+
mock_rest: Mock,
822+
mock_events: MagicMock,
823+
):
824+
failing_callback = Mock(side_effect=ValueError("Broken callback"))
825+
callback = Mock()
826+
client_with_events.add_callback(failing_callback)
827+
client_with_events.add_callback(callback)
828+
mock_rest.create_task.return_value = TaskResponse(task_id="foo")
829+
mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo")
830+
831+
ctx = Mock()
832+
ctx.correlation_id = "foo"
833+
834+
evt = DataEvent(name="start", doc={}, task_id="foo")
835+
836+
def subscribe(on_event: Callable[[AnyEvent, MessageContext], None]):
837+
on_event(evt, ctx)
838+
on_event(COMPLETE_EVENT, ctx)
839+
840+
mock_events.subscribe_to_all_events = subscribe # type: ignore
841+
842+
client_with_events.run_task(TaskRequest(name="foo", instrument_session="cm12345-1"))
843+
844+
assert failing_callback.mock_calls == [call(evt), call(COMPLETE_EVENT)]
845+
assert callback.mock_calls == [call(evt), call(COMPLETE_EVENT)]

0 commit comments

Comments
 (0)