Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/scripts/assign_azure_ai_user_role.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ IFS=',' read -r -a principal_ids_array <<< $principal_ids

echo "Assigning Foundry User role to users"

echo "Using provided Azure AI resource id: $aif_resource_id"
echo "Using provided Foundry resource id: $aif_resource_id"

for principal_id in "${principal_ids_array[@]}"; do

Expand Down
2 changes: 2 additions & 0 deletions src/App/src/hooks/usePlanWebSocket.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ export function usePlanWebSocket({
dispatch(addAgentMessage(errorAgent));
dispatch(planFailedFinal());
dispatch(setShowBufferingText(false));
dispatch(setSubmittingChatDisableInput(true));
scrollToBottom();
showToast(errorContent, 'error');
webSocketService.disconnect();
Expand Down Expand Up @@ -254,6 +255,7 @@ export function usePlanWebSocket({
dispatch(addAgentMessage(errorAgent));
dispatch(planFailedFinal());
dispatch(setShowBufferingText(false));
dispatch(setSubmittingChatDisableInput(true));
scrollToBottom();
showToast(errorContent, 'error');
webSocketService.disconnect();
Expand Down
7 changes: 3 additions & 4 deletions src/backend/v4/orchestration/orchestration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)

from common.config.app_config import config
from common.models.messages_af import TeamConfiguration
from common.models.messages_af import TeamConfiguration, PlanStatus

from common.database.database_base import DatabaseBase

Expand All @@ -39,7 +39,6 @@
from v4.orchestration.human_approval_manager import HumanApprovalMagenticManager
from v4.magentic_agents.magentic_agent_factory import MagenticAgentFactory
from common.database.database_factory import DatabaseFactory
from v4.models.models import PlanStatus


class OrchestrationManager:
Expand Down Expand Up @@ -296,7 +295,7 @@ async def get_current_or_new_orchestration(
# ---------------------------
# Execution
# ---------------------------
async def run_orchestration(self, user_id: str, input_task, plan_id: str = None) -> None:
async def run_orchestration(self, user_id: str, input_task, plan_id: Optional[str] = None) -> None:
"""
Execute the Magentic workflow for the provided user and task description.
"""
Expand Down Expand Up @@ -574,7 +573,7 @@ async def run_orchestration(self, user_id: str, input_task, plan_id: str = None)
memory_store = await DatabaseFactory.get_database(user_id=user_id)
plan = await memory_store.get_plan_by_plan_id(plan_id=self._plan_id)
if plan:
plan.overall_status = PlanStatus.FAILED
plan.overall_status = PlanStatus.failed
await memory_store.update_plan(plan)
self.logger.info("Plan '%s' status updated to FAILED", self._plan_id)
except Exception as db_error:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,17 @@ class MockPlanStatus:
FAILED = "failed"
COMPLETED = "completed"
IN_PROGRESS = "in_progress"
# messages_af.PlanStatus uses lowercase member names
failed = "failed"
completed = "completed"
in_progress = "in_progress"

sys.modules['v4.models'] = Mock()
sys.modules['v4.models.messages'] = Mock(WebsocketMessageType=MockWebsocketMessageType)
sys.modules['v4.models.models'] = Mock(PlanStatus=MockPlanStatus)
# Attach PlanStatus to the already-mocked messages_af module (production code now imports
# PlanStatus from common.models.messages_af, not v4.models.models).
sys.modules['common.models.messages_af'].PlanStatus = MockPlanStatus

# Mock v4.orchestration.human_approval_manager
class MockHumanApprovalMagenticManager:
Expand Down Expand Up @@ -933,6 +940,82 @@ async def test_run_orchestration_all_event_types(self):
# Verify streaming callback was called (for output event with AgentResponseUpdate data)
streaming_agent_response_callback.assert_called()

async def test_run_orchestration_marks_plan_failed_on_exception(self):
"""When orchestration raises and plan_id is set, plan.overall_status must be
updated to FAILED via DatabaseFactory/get_plan_by_plan_id/update_plan."""
mock_workflow = Mock()
mock_workflow.executors = {}
mock_workflow.run = Mock(side_effect=Exception("Workflow execution failed"))
orchestration_config.get_current_orchestration.return_value = mock_workflow

mock_plan = Mock()
mock_plan.overall_status = "in_progress"
mock_memory_store = Mock()
mock_memory_store.get_plan_by_plan_id = AsyncMock(return_value=mock_plan)
mock_memory_store.update_plan = AsyncMock()

db_factory_mock = sys.modules['common.database.database_factory'].DatabaseFactory
db_factory_mock.get_database = AsyncMock(return_value=mock_memory_store)

input_task = Mock()
input_task.description = "Test task"

with self.assertRaises(Exception):
await self.orchestration_manager.run_orchestration(
user_id=self.test_user_id,
input_task=input_task,
plan_id="plan-123",
)

db_factory_mock.get_database.assert_awaited_with(user_id=self.test_user_id)
mock_memory_store.get_plan_by_plan_id.assert_awaited_with(plan_id="plan-123")
mock_memory_store.update_plan.assert_awaited_once()
self.assertEqual(mock_plan.overall_status, "failed")

async def test_run_orchestration_db_failure_does_not_mask_original_error(self):
"""If the DB update itself fails, the original orchestration error must still
propagate (the DB error is logged and swallowed)."""
mock_workflow = Mock()
mock_workflow.executors = {}
original_error = RuntimeError("Workflow boom")
mock_workflow.run = Mock(side_effect=original_error)
orchestration_config.get_current_orchestration.return_value = mock_workflow

db_factory_mock = sys.modules['common.database.database_factory'].DatabaseFactory
db_factory_mock.get_database = AsyncMock(side_effect=Exception("DB unavailable"))

input_task = Mock()
input_task.description = "Test task"

with self.assertRaises(RuntimeError) as ctx:
await self.orchestration_manager.run_orchestration(
user_id=self.test_user_id,
input_task=input_task,
plan_id="plan-123",
)
self.assertIn("Workflow boom", str(ctx.exception))

async def test_run_orchestration_skips_db_update_when_no_plan_id(self):
"""When plan_id is not provided, the orchestration must not touch the DB on failure."""
mock_workflow = Mock()
mock_workflow.executors = {}
mock_workflow.run = Mock(side_effect=Exception("Workflow execution failed"))
orchestration_config.get_current_orchestration.return_value = mock_workflow

db_factory_mock = sys.modules['common.database.database_factory'].DatabaseFactory
db_factory_mock.get_database = AsyncMock()

input_task = Mock()
input_task.description = "Test task"

with self.assertRaises(Exception):
await self.orchestration_manager.run_orchestration(
user_id=self.test_user_id,
input_task=input_task,
)

db_factory_mock.get_database.assert_not_awaited()


class TestExtractResponseText(IsolatedAsyncioTestCase):
"""Test _extract_response_text method for various input types."""
Expand Down