2222 Client ,
2323 WorkflowExecution ,
2424)
25- from durable_workflow .errors import ServerError
25+ from durable_workflow .errors import InvalidArgument , ServerError , Unauthorized , WorkflowNotFound
2626from durable_workflow .interceptors import (
2727 ActivityHandler ,
2828 ActivityInterceptorContext ,
3232 WorkflowTaskHandler ,
3333 WorkflowTaskInterceptorContext ,
3434)
35- from durable_workflow .worker import Worker
35+ from durable_workflow .worker import Worker , _should_fail_workflow_task_after_completion_error
3636
3737
3838@workflow .defn (name = "test-wf" )
@@ -185,6 +185,28 @@ def compatible_cluster_info(**overrides: object) -> dict[str, object]:
185185 return info
186186
187187
188+ class TestWorkflowTaskCompletionErrorClassification :
189+ @pytest .mark .parametrize (
190+ ("error" , "should_fail" ),
191+ [
192+ (TimeoutError ("completion timed out" ), False ),
193+ (RuntimeError ("connection reset" ), False ),
194+ (ServerError (409 , {"reason" : "lease_expired" }), False ),
195+ (ServerError (409 , {"reason" : "workflow_task_attempt_mismatch" }), False ),
196+ (ServerError (429 , {"reason" : "rate_limited" }), False ),
197+ (ServerError (503 , {"reason" : "server_busy" }), False ),
198+ (ServerError (409 , {"reason" : "invalid_commands" }), True ),
199+ (InvalidArgument ("invalid command payload" ), True ),
200+ (Unauthorized ("missing bearer token" ), True ),
201+ (WorkflowNotFound ("wf-missing" ), True ),
202+ ],
203+ )
204+ def test_classifies_definite_and_ambiguous_completion_errors (
205+ self , error : BaseException , should_fail : bool
206+ ) -> None :
207+ assert _should_fail_workflow_task_after_completion_error (error ) is should_fail
208+
209+
188210class TestWorkerRegistration :
189211 @pytest .mark .asyncio
190212 async def test_register (self , mock_client : AsyncMock ) -> None :
@@ -502,7 +524,7 @@ async def test_schedule_activity_on_first_replay(self, mock_client: AsyncMock) -
502524 assert serializer .decode (commands [0 ]["arguments" ]["blob" ], codec = "json" ) == ["hello" ]
503525
504526 @pytest .mark .asyncio
505- async def test_workflow_task_completion_error_fails_task_for_fast_redispatch (
527+ async def test_workflow_task_ambiguous_completion_error_preserves_commands (
506528 self , mock_client : AsyncMock
507529 ) -> None :
508530 mock_client .complete_workflow_task .side_effect = TimeoutError ("completion timed out" )
@@ -518,15 +540,96 @@ async def test_workflow_task_completion_error_fails_task_for_fast_redispatch(
518540
519541 result = await worker ._run_workflow_task (task )
520542
543+ assert result is not None
544+ assert result [0 ]["type" ] == "schedule_activity"
545+ mock_client .complete_workflow_task .assert_awaited_once ()
546+ mock_client .fail_workflow_task .assert_not_called ()
547+
548+ @pytest .mark .asyncio
549+ async def test_workflow_task_definite_completion_rejection_fails_task (
550+ self , mock_client : AsyncMock
551+ ) -> None :
552+ mock_client .complete_workflow_task .side_effect = ServerError (409 , {"reason" : "invalid_commands" })
553+ worker = Worker (mock_client , task_queue = "q1" , workflows = [TestWorkflow ], activities = [])
554+ task = {
555+ "task_id" : "t-complete-invalid" ,
556+ "workflow_type" : "test-wf" ,
557+ "workflow_task_attempt" : 2 ,
558+ "history_events" : [],
559+ "arguments" : '["hello"]' ,
560+ "payload_codec" : "json" ,
561+ }
562+
563+ result = await worker ._run_workflow_task (task )
564+
565+ assert result is None
566+ mock_client .fail_workflow_task .assert_awaited_once ()
567+ call_kwargs = mock_client .fail_workflow_task .await_args .kwargs
568+ assert call_kwargs ["task_id" ] == "t-complete-invalid"
569+ assert call_kwargs ["workflow_task_attempt" ] == 2
570+ assert call_kwargs ["lease_owner" ] == worker .worker_id
571+ assert call_kwargs ["failure_type" ] == "ServerError"
572+ assert "invalid_commands" in call_kwargs ["message" ]
573+
574+ @pytest .mark .parametrize (
575+ ("completion_error" , "failure_type" , "message_fragment" ),
576+ [
577+ (Unauthorized ("missing bearer token" ), "Unauthorized" , "missing bearer token" ),
578+ (WorkflowNotFound ("wf-typed-missing" ), "WorkflowNotFound" , "wf-typed-missing" ),
579+ ],
580+ )
581+ @pytest .mark .asyncio
582+ async def test_workflow_task_typed_completion_rejection_fails_task (
583+ self ,
584+ mock_client : AsyncMock ,
585+ completion_error : Exception ,
586+ failure_type : str ,
587+ message_fragment : str ,
588+ ) -> None :
589+ mock_client .complete_workflow_task .side_effect = completion_error
590+ worker = Worker (mock_client , task_queue = "q1" , workflows = [TestWorkflow ], activities = [])
591+ task = {
592+ "task_id" : "t-complete-typed-rejection" ,
593+ "workflow_type" : "test-wf" ,
594+ "workflow_task_attempt" : 2 ,
595+ "history_events" : [],
596+ "arguments" : '["hello"]' ,
597+ "payload_codec" : "json" ,
598+ }
599+
600+ result = await worker ._run_workflow_task (task )
601+
521602 assert result is None
522603 mock_client .complete_workflow_task .assert_awaited_once ()
523604 mock_client .fail_workflow_task .assert_awaited_once ()
524605 call_kwargs = mock_client .fail_workflow_task .await_args .kwargs
525- assert call_kwargs ["task_id" ] == "t-complete-timeout "
606+ assert call_kwargs ["task_id" ] == "t-complete-typed-rejection "
526607 assert call_kwargs ["workflow_task_attempt" ] == 2
527608 assert call_kwargs ["lease_owner" ] == worker .worker_id
528- assert call_kwargs ["failure_type" ] == "TimeoutError"
529- assert "completion timed out" in call_kwargs ["message" ]
609+ assert call_kwargs ["failure_type" ] == failure_type
610+ assert message_fragment in call_kwargs ["message" ]
611+
612+ @pytest .mark .asyncio
613+ async def test_workflow_task_definite_completion_rejection_stays_failed_when_report_fails (
614+ self , mock_client : AsyncMock
615+ ) -> None :
616+ mock_client .complete_workflow_task .side_effect = ServerError (409 , {"reason" : "invalid_commands" })
617+ mock_client .fail_workflow_task .side_effect = RuntimeError ("failure report unavailable" )
618+ worker = Worker (mock_client , task_queue = "q1" , workflows = [TestWorkflow ], activities = [])
619+ task = {
620+ "task_id" : "t-complete-invalid-report-fails" ,
621+ "workflow_type" : "test-wf" ,
622+ "workflow_task_attempt" : 2 ,
623+ "history_events" : [],
624+ "arguments" : '["hello"]' ,
625+ "payload_codec" : "json" ,
626+ }
627+
628+ result = await worker ._run_workflow_task (task )
629+
630+ assert result is None
631+ mock_client .complete_workflow_task .assert_awaited_once ()
632+ mock_client .fail_workflow_task .assert_awaited_once ()
530633
531634 @pytest .mark .asyncio
532635 async def test_workflow_command_payload_warning_uses_client_policy (
@@ -701,7 +804,7 @@ async def test_update_backed_workflow_task_completes_update_command(
701804 mock_client .fail_workflow_task .assert_not_called ()
702805
703806 @pytest .mark .asyncio
704- async def test_update_task_completion_error_fails_task_for_fast_redispatch (
807+ async def test_update_task_ambiguous_completion_error_preserves_command (
705808 self , mock_client : AsyncMock
706809 ) -> None :
707810 mock_client .complete_workflow_task .side_effect = TimeoutError ("update completion timed out" )
@@ -729,14 +832,10 @@ async def test_update_task_completion_error_fails_task_for_fast_redispatch(
729832
730833 result = await worker ._run_workflow_task (task )
731834
732- assert result is None
835+ assert result is not None
836+ assert result [0 ]["type" ] == "complete_update"
733837 mock_client .complete_workflow_task .assert_awaited_once ()
734- mock_client .fail_workflow_task .assert_awaited_once ()
735- call_kwargs = mock_client .fail_workflow_task .await_args .kwargs
736- assert call_kwargs ["task_id" ] == "t-update-timeout"
737- assert call_kwargs ["workflow_task_attempt" ] == 3
738- assert call_kwargs ["failure_type" ] == "TimeoutError"
739- assert "update completion timed out" in call_kwargs ["message" ]
838+ mock_client .fail_workflow_task .assert_not_called ()
740839
741840 @pytest .mark .asyncio
742841 async def test_query_task_executes_registered_query (self , mock_client : AsyncMock ) -> None :
0 commit comments