@@ -764,3 +764,82 @@ def test_plan_invalid_param_mapping(args, kwargs, msg):
764764 with pytest .raises (TypeError , match = msg ):
765765 plan (* args , ** kwargs )
766766 client .run_task .assert_not_called ()
767+
768+
769+ def test_adding_removing_callback (client ):
770+ def callback (* a , ** kw ):
771+ pass
772+
773+ cb_id = client .add_callback (callback )
774+ assert len (client .callbacks ) == 1
775+ client .remove_callback (cb_id )
776+ assert len (client .callbacks ) == 0
777+
778+
779+ @pytest .mark .parametrize (
780+ "test_event" ,
781+ [
782+ WorkerEvent (
783+ state = WorkerState .RUNNING ,
784+ task_status = TaskStatus (
785+ task_id = "foo" ,
786+ task_complete = False ,
787+ task_failed = False ,
788+ ),
789+ ),
790+ ProgressEvent (task_id = "foo" ),
791+ DataEvent (name = "start" , doc = {}, task_id = "0000-1111" ),
792+ ],
793+ )
794+ def test_client_callbacks (
795+ client_with_events : BlueapiClient ,
796+ mock_rest : Mock ,
797+ mock_events : MagicMock ,
798+ test_event : AnyEvent ,
799+ ):
800+ callback = Mock ()
801+ client_with_events .add_callback (callback )
802+ mock_rest .create_task .return_value = TaskResponse (task_id = "foo" )
803+ mock_rest .update_worker_task .return_value = TaskResponse (task_id = "foo" )
804+
805+ ctx = Mock ()
806+ ctx .correlation_id = "foo"
807+
808+ def subscribe (on_event : Callable [[AnyEvent , MessageContext ], None ]):
809+ on_event (test_event , ctx )
810+ on_event (COMPLETE_EVENT , ctx )
811+
812+ mock_events .subscribe_to_all_events = subscribe # type: ignore
813+
814+ client_with_events .run_task (TaskRequest (name = "foo" , instrument_session = "cm12345-1" ))
815+
816+ assert callback .mock_calls == [call (test_event ), call (COMPLETE_EVENT )]
817+
818+
819+ def test_client_callback_failures (
820+ client_with_events : BlueapiClient ,
821+ mock_rest : Mock ,
822+ mock_events : MagicMock ,
823+ ):
824+ failing_callback = Mock (side_effect = ValueError ("Broken callback" ))
825+ callback = Mock ()
826+ client_with_events .add_callback (failing_callback )
827+ client_with_events .add_callback (callback )
828+ mock_rest .create_task .return_value = TaskResponse (task_id = "foo" )
829+ mock_rest .update_worker_task .return_value = TaskResponse (task_id = "foo" )
830+
831+ ctx = Mock ()
832+ ctx .correlation_id = "foo"
833+
834+ evt = DataEvent (name = "start" , doc = {}, task_id = "foo" )
835+
836+ def subscribe (on_event : Callable [[AnyEvent , MessageContext ], None ]):
837+ on_event (evt , ctx )
838+ on_event (COMPLETE_EVENT , ctx )
839+
840+ mock_events .subscribe_to_all_events = subscribe # type: ignore
841+
842+ client_with_events .run_task (TaskRequest (name = "foo" , instrument_session = "cm12345-1" ))
843+
844+ assert failing_callback .mock_calls == [call (evt ), call (COMPLETE_EVENT )]
845+ assert callback .mock_calls == [call (evt ), call (COMPLETE_EVENT )]
0 commit comments