@@ -324,10 +324,17 @@ class MockPlanStatus:
324324 FAILED = "failed"
325325 COMPLETED = "completed"
326326 IN_PROGRESS = "in_progress"
327+ # messages_af.PlanStatus uses lowercase member names
328+ failed = "failed"
329+ completed = "completed"
330+ in_progress = "in_progress"
327331
328332sys .modules ['v4.models' ] = Mock ()
329333sys .modules ['v4.models.messages' ] = Mock (WebsocketMessageType = MockWebsocketMessageType )
330334sys .modules ['v4.models.models' ] = Mock (PlanStatus = MockPlanStatus )
335+ # Attach PlanStatus to the already-mocked messages_af module (production code now imports
336+ # PlanStatus from common.models.messages_af, not v4.models.models).
337+ sys .modules ['common.models.messages_af' ].PlanStatus = MockPlanStatus
331338
332339# Mock v4.orchestration.human_approval_manager
333340class MockHumanApprovalMagenticManager :
@@ -933,6 +940,82 @@ async def test_run_orchestration_all_event_types(self):
933940 # Verify streaming callback was called (for output event with AgentResponseUpdate data)
934941 streaming_agent_response_callback .assert_called ()
935942
943+ async def test_run_orchestration_marks_plan_failed_on_exception (self ):
944+ """When orchestration raises and plan_id is set, plan.overall_status must be
945+ updated to FAILED via DatabaseFactory/get_plan_by_plan_id/update_plan."""
946+ mock_workflow = Mock ()
947+ mock_workflow .executors = {}
948+ mock_workflow .run = Mock (side_effect = Exception ("Workflow execution failed" ))
949+ orchestration_config .get_current_orchestration .return_value = mock_workflow
950+
951+ mock_plan = Mock ()
952+ mock_plan .overall_status = "in_progress"
953+ mock_memory_store = Mock ()
954+ mock_memory_store .get_plan_by_plan_id = AsyncMock (return_value = mock_plan )
955+ mock_memory_store .update_plan = AsyncMock ()
956+
957+ db_factory_mock = sys .modules ['common.database.database_factory' ].DatabaseFactory
958+ db_factory_mock .get_database = AsyncMock (return_value = mock_memory_store )
959+
960+ input_task = Mock ()
961+ input_task .description = "Test task"
962+
963+ with self .assertRaises (Exception ):
964+ await self .orchestration_manager .run_orchestration (
965+ user_id = self .test_user_id ,
966+ input_task = input_task ,
967+ plan_id = "plan-123" ,
968+ )
969+
970+ db_factory_mock .get_database .assert_awaited_with (user_id = self .test_user_id )
971+ mock_memory_store .get_plan_by_plan_id .assert_awaited_with (plan_id = "plan-123" )
972+ mock_memory_store .update_plan .assert_awaited_once ()
973+ self .assertEqual (mock_plan .overall_status , "failed" )
974+
975+ async def test_run_orchestration_db_failure_does_not_mask_original_error (self ):
976+ """If the DB update itself fails, the original orchestration error must still
977+ propagate (the DB error is logged and swallowed)."""
978+ mock_workflow = Mock ()
979+ mock_workflow .executors = {}
980+ original_error = RuntimeError ("Workflow boom" )
981+ mock_workflow .run = Mock (side_effect = original_error )
982+ orchestration_config .get_current_orchestration .return_value = mock_workflow
983+
984+ db_factory_mock = sys .modules ['common.database.database_factory' ].DatabaseFactory
985+ db_factory_mock .get_database = AsyncMock (side_effect = Exception ("DB unavailable" ))
986+
987+ input_task = Mock ()
988+ input_task .description = "Test task"
989+
990+ with self .assertRaises (RuntimeError ) as ctx :
991+ await self .orchestration_manager .run_orchestration (
992+ user_id = self .test_user_id ,
993+ input_task = input_task ,
994+ plan_id = "plan-123" ,
995+ )
996+ self .assertIn ("Workflow boom" , str (ctx .exception ))
997+
998+ async def test_run_orchestration_skips_db_update_when_no_plan_id (self ):
999+ """When plan_id is not provided, the orchestration must not touch the DB on failure."""
1000+ mock_workflow = Mock ()
1001+ mock_workflow .executors = {}
1002+ mock_workflow .run = Mock (side_effect = Exception ("Workflow execution failed" ))
1003+ orchestration_config .get_current_orchestration .return_value = mock_workflow
1004+
1005+ db_factory_mock = sys .modules ['common.database.database_factory' ].DatabaseFactory
1006+ db_factory_mock .get_database = AsyncMock ()
1007+
1008+ input_task = Mock ()
1009+ input_task .description = "Test task"
1010+
1011+ with self .assertRaises (Exception ):
1012+ await self .orchestration_manager .run_orchestration (
1013+ user_id = self .test_user_id ,
1014+ input_task = input_task ,
1015+ )
1016+
1017+ db_factory_mock .get_database .assert_not_awaited ()
1018+
9361019
9371020class TestExtractResponseText (IsolatedAsyncioTestCase ):
9381021 """Test _extract_response_text method for various input types."""
0 commit comments