diff --git a/infra/vscode_web/requirements.txt b/infra/vscode_web/requirements.txt index 272ff37b1..d7ff98e4b 100644 --- a/infra/vscode_web/requirements.txt +++ b/infra/vscode_web/requirements.txt @@ -1,3 +1,3 @@ -azure-ai-projects==2.0.0b3 +azure-ai-projects==2.1.0 azure-identity==1.20.0 ansible-core~=2.17.0 \ No newline at end of file diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py index b150f00cf..c31122259 100644 --- a/src/backend/orchestrator.py +++ b/src/backend/orchestrator.py @@ -22,14 +22,12 @@ from typing import AsyncIterator, Optional, cast from agent_framework import ( - ChatMessage, - HandoffBuilder, - HandoffAgentUserRequest, - RequestInfoEvent, - WorkflowOutputEvent, - WorkflowStatusEvent, + Agent, + Message, + WorkflowEventType, ) -from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.orchestrations import HandoffBuilder, HandoffAgentUserRequest +from agent_framework.openai import OpenAIChatCompletionClient from azure.identity import DefaultAzureCredential # Foundry imports - only used when USE_FOUNDRY=true @@ -48,6 +46,11 @@ # Token endpoint for Azure Cognitive Services (used for Azure OpenAI) TOKEN_ENDPOINT = "https://cognitiveservices.azure.com/.default" +# Event type constants for type-safe dispatch (avoids string typos) +EVENT_STATUS: WorkflowEventType = "status" +EVENT_REQUEST_INFO: WorkflowEventType = "request_info" +EVENT_OUTPUT: WorkflowEventType = "output" + # Harmful content patterns to detect in USER INPUT before processing # This provides proactive content safety by blocking harmful requests at the input layer @@ -120,9 +123,9 @@ def _check_input_for_harmful_content(message: str) -> tuple[bool, str]: r"You are a Text Content Agent", r"You are an Image Content Agent", r"You are a Compliance Agent", - # Handoff instructions - r"hand off to \w+_agent", - r"hand back to \w+_agent", + # Handoff instructions (match both underscore and hyphen agent names) + r"hand off to [\w\-]+[_\-]agent", + r"hand back to [\w\-]+[_\-]agent", r"may hand off to", r"After (?:generating|completing|validation|parsing)", # Internal workflow markers @@ -139,8 +142,8 @@ def _check_input_for_harmful_content(message: str) -> tuple[bool, str]: # RAI internal instructions r"NEVER generate images that contain:", r"Responsible AI - Image Generation Rules", - # Agent framework references - r"compliance_agent|triage_agent|planning_agent|research_agent|text_content_agent|image_content_agent", + # Agent framework references (match both underscore and hyphen separators) + r"compliance[_\-]agent|triage[_\-]agent|planning[_\-]agent|research[_\-]agent|text[_\-]content[_\-]agent|image[_\-]content[_\-]agent", ] _SYSTEM_PROMPT_PATTERNS_COMPILED = [re.compile(pattern, re.IGNORECASE | re.DOTALL) for pattern in SYSTEM_PROMPT_PATTERNS] @@ -485,7 +488,7 @@ class ContentGenerationOrchestrator: Microsoft Agent Framework's HandoffBuilder. Supports two modes: - 1. Azure OpenAI Direct (default): Uses AzureOpenAIChatClient with ad_token_provider + 1. Azure OpenAI Direct (default): Uses OpenAIChatCompletionClient with DefaultAzureCredential 2. Azure AI Foundry: Uses AIProjectClient with project endpoint (set USE_FOUNDRY=true) Agents: @@ -498,7 +501,7 @@ class ContentGenerationOrchestrator: """ def __init__(self): - self._chat_client = None # Always AzureOpenAIChatClient + self._chat_client = None # OpenAIChatCompletionClient instance self._project_client = None # AIProjectClient for Foundry mode (used for image generation) self._agents: dict = {} self._rai_agent = None @@ -536,27 +539,21 @@ def _get_chat_client(self): # Store the project client for image generation self._project_client = project_client - # For chat completions, use the direct Azure OpenAI endpoint # The Foundry project uses Azure OpenAI under the hood, and we need the direct endpoint # to properly authenticate with Cognitive Services token azure_endpoint = app_settings.azure_openai.endpoint if not azure_endpoint: raise ValueError("AZURE_OPENAI_ENDPOINT is required for Foundry mode chat completions") - def get_token() -> str: - """Token provider callable - invoked for each request to ensure fresh tokens.""" - token = self._credential.get_token(TOKEN_ENDPOINT) - return token.token - model_deployment = app_settings.ai_foundry.model_deployment or app_settings.azure_openai.gpt_model api_version = app_settings.azure_openai.api_version logger.info(f"Foundry mode using Azure OpenAI endpoint: {azure_endpoint}, deployment: {model_deployment}") - self._chat_client = AzureOpenAIChatClient( - endpoint=azure_endpoint, - deployment_name=model_deployment, + self._chat_client = OpenAIChatCompletionClient( + azure_endpoint=azure_endpoint, + model=model_deployment, api_version=api_version, - ad_token_provider=get_token, + credential=self._credential, ) else: # Azure OpenAI Direct mode @@ -564,17 +561,12 @@ def get_token() -> str: if not endpoint: raise ValueError("AZURE_OPENAI_ENDPOINT is not configured") - def get_token() -> str: - """Token provider callable - invoked for each request to ensure fresh tokens.""" - token = self._credential.get_token(TOKEN_ENDPOINT) - return token.token - - logger.info("Using Azure OpenAI Direct mode with ad_token_provider") - self._chat_client = AzureOpenAIChatClient( - endpoint=endpoint, - deployment_name=app_settings.azure_openai.gpt_model, + logger.info("Using Azure OpenAI Direct mode with DefaultAzureCredential") + self._chat_client = OpenAIChatCompletionClient( + azure_endpoint=endpoint, + model=app_settings.azure_openai.gpt_model, api_version=app_settings.azure_openai.api_version, - ad_token_provider=get_token, + credential=self._credential, ) return self._chat_client @@ -589,40 +581,60 @@ def initialize(self) -> None: # Get the chat client chat_client = self._get_chat_client() - # Agent names - use underscores (AzureOpenAIChatClient works with both modes now) + # Agent names - always use underscores so that instruction strings + # (TRIAGE_INSTRUCTIONS, *_CONTENT_INSTRUCTIONS, etc.) and the + # SYSTEM_PROMPT_PATTERNS leakage-detection regexes stay in sync. + # Foundry workflows accept underscore names; no hyphen conversion needed. name_sep = "_" # Create all agents - triage_agent = chat_client.create_agent( + # NOTE: Handoff workflow participants must set + # require_per_service_call_history_persistence=True so local conversation + # history stays consistent with the service across handoff tool-call + # short-circuits (required by agent_framework.orchestrations.HandoffBuilder). + triage_agent = Agent( + client=chat_client, name=f"triage{name_sep}agent", instructions=TRIAGE_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - planning_agent = chat_client.create_agent( + planning_agent = Agent( + client=chat_client, name=f"planning{name_sep}agent", instructions=PLANNING_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - research_agent = chat_client.create_agent( + research_agent = Agent( + client=chat_client, name=f"research{name_sep}agent", instructions=RESEARCH_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - text_content_agent = chat_client.create_agent( + text_content_agent = Agent( + client=chat_client, name=f"text{name_sep}content{name_sep}agent", instructions=TEXT_CONTENT_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - image_content_agent = chat_client.create_agent( + image_content_agent = Agent( + client=chat_client, name=f"image{name_sep}content{name_sep}agent", instructions=IMAGE_CONTENT_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - compliance_agent = chat_client.create_agent( + compliance_agent = Agent( + client=chat_client, name=f"compliance{name_sep}agent", instructions=COMPLIANCE_INSTRUCTIONS, + require_per_service_call_history_persistence=True, ) - self._rai_agent = chat_client.create_agent( + self._rai_agent = Agent( + client=chat_client, name=f"rai{name_sep}agent", instructions=RAI_INSTRUCTIONS, ) @@ -636,7 +648,7 @@ def initialize(self) -> None: "compliance": compliance_agent, } - # Workflow name - Foundry requires hyphens + # Workflow name workflow_name = f"content{name_sep}generation{name_sep}workflow" # Build the handoff workflow @@ -736,21 +748,21 @@ async def process_message( events.append(event) # Handle different event types from the workflow - if isinstance(event, WorkflowStatusEvent): + if event.type == EVENT_STATUS: + status_name = event.state.name if event.state else str(event.data) yield { "type": "status", - "content": event.state.name, + "content": status_name, "is_final": False, "metadata": {"conversation_id": conversation_id} } - elif isinstance(event, RequestInfoEvent): + elif event.type == EVENT_REQUEST_INFO: # Workflow is requesting user input if isinstance(event.data, HandoffAgentUserRequest): - # Extract conversation history from agent_response.messages (updated API) - messages = event.data.agent_response.messages if hasattr(event.data, 'agent_response') and event.data.agent_response else [] - if not isinstance(messages, list): - messages = [messages] if messages else [] + # Extract conversation history from agent_response.messages + agent_resp = event.data.agent_response + messages = list(agent_resp.messages) if agent_resp and agent_resp.messages else [] conversation_text = "\n".join([ f"{msg.author_name or msg.role.value}: {msg.text}" @@ -758,13 +770,13 @@ async def process_message( ]) # Get the last message content and filter any system prompt leakage - last_msg_content = messages[-1].text if messages else (event.data.agent_response.text if hasattr(event.data, 'agent_response') and event.data.agent_response else "") + last_msg_content = messages[-1].text if messages else (agent_resp.text if agent_resp else "") last_msg_content = _filter_system_prompt_from_response(last_msg_content) - last_msg_agent = messages[-1].author_name if messages and hasattr(messages[-1], 'author_name') else "unknown" + last_msg_agent = messages[-1].author_name if messages else "unknown" yield { "type": "agent_response", - "agent": last_msg_agent, + "agent": last_msg_agent or "unknown", "content": last_msg_content, "conversation_history": conversation_text, "is_final": False, @@ -773,9 +785,8 @@ async def process_message( "metadata": {"conversation_id": conversation_id} } - elif isinstance(event, WorkflowOutputEvent): - # Final output from the workflow - conversation = cast(list[ChatMessage], event.data) + elif event.type == EVENT_OUTPUT: + conversation = cast(list[Message], event.data) if isinstance(conversation, list) and conversation: # Get the last assistant message as the final response assistant_messages = [ @@ -841,29 +852,29 @@ async def send_user_response( try: responses = {request_id: user_response} async for event in self._workflow.send_responses_streaming(responses): - if isinstance(event, WorkflowStatusEvent): + if event.type == EVENT_STATUS: + status_name = event.state.name if event.state else str(event.data) yield { "type": "status", - "content": event.state.name, + "content": status_name, "is_final": False, "metadata": {"conversation_id": conversation_id} } - elif isinstance(event, RequestInfoEvent): + elif event.type == EVENT_REQUEST_INFO: if isinstance(event.data, HandoffAgentUserRequest): - # Get messages from agent_response (updated API) - messages = event.data.agent_response.messages if hasattr(event.data, 'agent_response') and event.data.agent_response else [] - if not isinstance(messages, list): - messages = [messages] if messages else [] + # Get messages from agent_response + agent_resp = event.data.agent_response + messages = list(agent_resp.messages) if agent_resp and agent_resp.messages else [] # Get the last message content and filter any system prompt leakage - last_msg_content = messages[-1].text if messages else (event.data.agent_response.text if hasattr(event.data, 'agent_response') and event.data.agent_response else "") + last_msg_content = messages[-1].text if messages else (agent_resp.text if agent_resp else "") last_msg_content = _filter_system_prompt_from_response(last_msg_content) - last_msg_agent = messages[-1].author_name if messages and hasattr(messages[-1], 'author_name') else "unknown" + last_msg_agent = messages[-1].author_name if messages else "unknown" yield { "type": "agent_response", - "agent": last_msg_agent, + "agent": last_msg_agent or "unknown", "content": last_msg_content, "is_final": False, "requires_user_input": True, @@ -871,8 +882,8 @@ async def send_user_response( "metadata": {"conversation_id": conversation_id} } - elif isinstance(event, WorkflowOutputEvent): - conversation = cast(list[ChatMessage], event.data) + elif event.type == EVENT_OUTPUT: + conversation = cast(list[Message], event.data) if isinstance(conversation, list) and conversation: assistant_messages = [ msg for msg in conversation diff --git a/src/backend/requirements.txt b/src/backend/requirements.txt index fda72578f..0def70e41 100644 --- a/src/backend/requirements.txt +++ b/src/backend/requirements.txt @@ -6,8 +6,9 @@ quart-cors>=0.7.0 hypercorn>=0.17.0 # Microsoft Agent Framework -agent-framework-azure-ai==1.0.0b260114 -agent-framework-core==1.0.0b260114 +agent-framework-foundry==1.1.1 +agent-framework-core==1.1.1 +agent-framework-orchestrations==1.0.0b260421 # OpenTelemetry (required by agent-framework) opentelemetry-semantic-conventions-ai==0.4.13 @@ -18,7 +19,7 @@ azure-cosmos>=4.7.0 azure-storage-blob>=12.22.0 azure-search-documents>=11.4.0 azure-ai-contentsafety>=1.0.0 -azure-ai-projects==2.0.0b3 # Azure AI Foundry SDK (optional, for USE_FOUNDRY=true) +azure-ai-projects==2.1.0 # Azure AI Foundry SDK (optional, for USE_FOUNDRY=true) # OpenAI openai>=1.45.0 diff --git a/src/backend/services/title_service.py b/src/backend/services/title_service.py index e849ca22d..92289ef12 100644 --- a/src/backend/services/title_service.py +++ b/src/backend/services/title_service.py @@ -9,16 +9,14 @@ import re from typing import Optional -from agent_framework.azure import AzureOpenAIChatClient +from agent_framework import Agent +from agent_framework.openai import OpenAIChatCompletionClient from azure.identity import DefaultAzureCredential from settings import app_settings logger = logging.getLogger(__name__) -# Token endpoint for Azure OpenAI authentication -TOKEN_ENDPOINT = "https://cognitiveservices.azure.com/.default" - # Title generation instructions (from MS reference accelerator) TITLE_INSTRUCTIONS = """Summarize the conversation so far into a 4-word or less title. Do not use any quotation marks or punctuation. @@ -57,20 +55,15 @@ def initialize(self) -> None: api_version = app_settings.azure_openai.api_version - # Create token provider function - def get_token() -> str: - """Token provider callable - invoked for each request to ensure fresh tokens.""" - token = self._credential.get_token(TOKEN_ENDPOINT) - return token.token - - chat_client = AzureOpenAIChatClient( - endpoint=endpoint, - deployment_name=deployment, + chat_client = OpenAIChatCompletionClient( + azure_endpoint=endpoint, + model=deployment, api_version=api_version, - ad_token_provider=get_token, + credential=self._credential, ) - self._agent = chat_client.create_agent( + self._agent = Agent( + client=chat_client, name="title_agent", instructions=TITLE_INSTRUCTIONS, ) diff --git a/src/tests/services/test_orchestrator.py b/src/tests/services/test_orchestrator.py index 6063dbd2c..bc7ad7dd7 100644 --- a/src/tests/services/test_orchestrator.py +++ b/src/tests/services/test_orchestrator.py @@ -135,6 +135,24 @@ def test_filter_system_prompt_handoff(): assert "text_content_agent" not in filtered +def test_filter_system_prompt_handoff_hyphenated(): + """Test filtering of handoff instructions with hyphenated agent names.""" + + response = "I'll hand off to text-content-agent now" + filtered = _filter_system_prompt_from_response(response) + + assert "text-content-agent" not in filtered + + +def test_filter_system_prompt_handback_hyphenated(): + """Test filtering of hand back instructions with hyphenated agent names.""" + + response = "Let me hand back to triage-agent with results" + filtered = _filter_system_prompt_from_response(response) + + assert "triage-agent" not in filtered + + def test_filter_system_prompt_critical(): """Test filtering of critical instruction markers.""" @@ -238,7 +256,8 @@ async def test_orchestrator_initialize_creates_workflow(): """Test that initialize creates the workflow.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -254,13 +273,16 @@ async def test_orchestrator_initialize_creates_workflow(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() mock_builder_instance = MagicMock() + mock_builder_instance.participants.return_value = mock_builder_instance + mock_builder_instance.with_start_agent.return_value = mock_builder_instance mock_builder_instance.add_agent.return_value = mock_builder_instance mock_builder_instance.add_handoff.return_value = mock_builder_instance + mock_builder_instance.with_termination_condition.return_value = mock_builder_instance mock_builder_instance.build.return_value = mock_workflow mock_builder.return_value = mock_builder_instance @@ -276,7 +298,8 @@ async def test_orchestrator_initialize_foundry_mode(): """Test orchestrator in foundry mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("orchestrator.FOUNDRY_AVAILABLE", True), \ patch("orchestrator.AIProjectClient"): @@ -296,13 +319,16 @@ async def test_orchestrator_initialize_foundry_mode(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() mock_builder_instance = MagicMock() + mock_builder_instance.participants.return_value = mock_builder_instance + mock_builder_instance.with_start_agent.return_value = mock_builder_instance mock_builder_instance.add_agent.return_value = mock_builder_instance mock_builder_instance.add_handoff.return_value = mock_builder_instance + mock_builder_instance.with_termination_condition.return_value = mock_builder_instance mock_builder_instance.build.return_value = mock_workflow mock_builder.return_value = mock_builder_instance @@ -339,7 +365,8 @@ async def test_process_message_safe_content(): """Test that process_message allows safe content.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -355,22 +382,23 @@ async def test_process_message_safe_content(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Create async generator for workflow.run_stream - # WorkflowOutputEvent.data should be a list of ChatMessage objects + # Event with type="output" and data as list of Message objects async def mock_stream(*_args, **_kwargs): - from agent_framework import WorkflowOutputEvent + # Create mock event with type attribute - # Create a mock ChatMessage with expected attributes + # Create a mock Message with expected attributes mock_message = MagicMock() mock_message.role.value = "assistant" mock_message.text = "Here's your marketing content" mock_message.author_name = "content_agent" - # Use real WorkflowOutputEvent so isinstance() check passes - event = WorkflowOutputEvent(data=[mock_message], source_executor_id="test") + event = MagicMock() + event.type = "output" + event.data = [mock_message] yield event mock_workflow = MagicMock() @@ -424,7 +452,8 @@ async def test_parse_brief_complete(): """Test parse_brief with complete brief data.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -440,7 +469,7 @@ async def test_parse_brief_complete(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Mock planning agent response @@ -513,7 +542,8 @@ async def test_select_products_add_action(): """Test select_products with add action.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -529,7 +559,7 @@ async def test_select_products_add_action(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_research_agent = AsyncMock() @@ -564,7 +594,8 @@ async def test_select_products_json_error(): """Test select_products handles JSON parsing errors.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -580,7 +611,7 @@ async def test_select_products_json_error(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_research_agent = AsyncMock() @@ -611,7 +642,8 @@ async def test_generate_content_text_only(): """Test generate_content without images.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("orchestrator._check_input_for_harmful_content") as mock_check: @@ -630,7 +662,7 @@ async def test_generate_content_text_only(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_text_agent = AsyncMock() @@ -669,7 +701,8 @@ async def test_generate_content_with_compliance_violations(): """Test generate_content with compliance violations.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("orchestrator._check_input_for_harmful_content") as mock_check: @@ -688,7 +721,7 @@ async def test_generate_content_with_compliance_violations(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_text_agent = AsyncMock() @@ -814,7 +847,8 @@ def test_get_orchestrator_singleton(): """Test that get_orchestrator returns singleton instance.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -830,7 +864,7 @@ def test_get_orchestrator_singleton(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1005,7 +1039,8 @@ async def test_process_message_empty_events(): """Test process_message with workflow returning no events.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1021,7 +1056,7 @@ async def test_process_message_empty_events(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client async def empty_stream(*_args, **_kwargs): @@ -1053,7 +1088,8 @@ async def test_parse_brief_rai_agent_blocks(): """Test parse_brief when RAI agent returns TRUE (blocked).""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1069,7 +1105,7 @@ async def test_parse_brief_rai_agent_blocks(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1098,7 +1134,8 @@ async def test_parse_brief_rai_agent_exception(): """Test parse_brief continues when RAI agent raises exception.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1114,7 +1151,7 @@ async def test_parse_brief_rai_agent_exception(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1148,7 +1185,8 @@ async def test_parse_brief_incomplete_fields(): """Test parse_brief with incomplete brief returns clarifying message.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1164,7 +1202,7 @@ async def test_parse_brief_incomplete_fields(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1204,7 +1242,8 @@ async def test_parse_brief_json_in_code_block(): """Test parse_brief extracts JSON from markdown code blocks.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1220,7 +1259,7 @@ async def test_parse_brief_json_in_code_block(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1258,7 +1297,8 @@ async def test_generate_content_text_content(): """Test generate_content produces text content.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1274,7 +1314,7 @@ async def test_generate_content_text_content(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1325,7 +1365,8 @@ async def test_regenerate_image_foundry_mode(): """Test regenerate_image in Foundry mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = True @@ -1344,7 +1385,7 @@ async def test_regenerate_image_foundry_mode(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1382,7 +1423,8 @@ async def test_regenerate_image_exception(): """Test regenerate_image handles exceptions gracefully.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = True @@ -1401,7 +1443,7 @@ async def test_regenerate_image_exception(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_workflow = MagicMock() @@ -1489,7 +1531,7 @@ async def test_get_chat_client_foundry_mode(): """Test _get_chat_client in Foundry mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ patch("orchestrator.FOUNDRY_AVAILABLE", True): mock_settings.ai_foundry.use_foundry = True @@ -1531,7 +1573,7 @@ async def test_process_message_with_context(): """Test process_message with context parameter.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client: + patch("orchestrator.OpenAIChatCompletionClient") as mock_client: mock_settings.ai_foundry.use_foundry = False mock_settings.azure_openai.endpoint = "https://test.openai.azure.com" @@ -1546,7 +1588,7 @@ async def test_process_message_with_context(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Track if workflow was called @@ -1586,7 +1628,7 @@ async def test_send_user_response_safe_content(): """Test send_user_response allows safe content through.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client: + patch("orchestrator.OpenAIChatCompletionClient") as mock_client: mock_settings.ai_foundry.use_foundry = False mock_settings.azure_openai.endpoint = "https://test.openai.azure.com" @@ -1601,7 +1643,7 @@ async def test_send_user_response_safe_content(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client call_tracker = {"called": False, "responses": None} @@ -1637,7 +1679,8 @@ async def test_parse_brief_json_with_backticks(): """Test parse_brief extracting JSON from ```json blocks.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1653,7 +1696,7 @@ async def test_parse_brief_json_with_backticks(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Mock planning agent to return JSON in ```json block @@ -1705,7 +1748,8 @@ async def test_parse_brief_with_dict_field_value(): """Test parse_brief handles dict values in extracted_fields.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1721,7 +1765,7 @@ async def test_parse_brief_with_dict_field_value(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Mock planning agent with dict field values (line 1031) @@ -1777,7 +1821,8 @@ async def test_parse_brief_fallback_extraction(): """Test parse_brief falls back to _extract_brief_from_text on parse error.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = False @@ -1793,7 +1838,7 @@ async def test_parse_brief_fallback_extraction(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Mock planning agent with invalid JSON @@ -2059,7 +2104,8 @@ async def test_generate_content_with_foundry_image(): """Test generate_content generates images in Foundry mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder: mock_settings.ai_foundry.use_foundry = True @@ -2076,7 +2122,7 @@ async def test_generate_content_with_foundry_image(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client # Mock agents @@ -2130,7 +2176,8 @@ async def test_generate_content_direct_mode_image(): """Test generate_content generates images in Direct mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("agents.image_content_agent.generate_image") as mock_generate_image: @@ -2147,7 +2194,7 @@ async def test_generate_content_direct_mode_image(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_text_agent = AsyncMock() @@ -2210,7 +2257,8 @@ async def test_regenerate_image_direct_mode(): """Test regenerate_image in Direct mode.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("agents.image_content_agent.generate_image") as mock_generate_image: @@ -2227,7 +2275,7 @@ async def test_regenerate_image_direct_mode(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_image_agent = AsyncMock() @@ -2285,7 +2333,8 @@ async def test_regenerate_image_failure(): """Test regenerate_image handles generation failure.""" with patch("orchestrator.app_settings") as mock_settings, \ patch("orchestrator.DefaultAzureCredential") as mock_cred, \ - patch("orchestrator.AzureOpenAIChatClient") as mock_client, \ + patch("orchestrator.OpenAIChatCompletionClient") as mock_client, \ + patch("orchestrator.Agent") as mock_agent_cls, \ patch("orchestrator.HandoffBuilder") as mock_builder, \ patch("agents.image_content_agent.generate_image") as mock_generate_image: @@ -2302,7 +2351,7 @@ async def test_regenerate_image_failure(): mock_cred.return_value = mock_credential mock_chat_client = MagicMock() - mock_chat_client.create_agent.return_value = MagicMock() + # Agent is now instantiated directly (not via chat_client.create_agent) mock_client.return_value = mock_chat_client mock_image_agent = AsyncMock() diff --git a/src/tests/test_title_generation_service.py b/src/tests/test_title_generation_service.py index caf051e54..efcacb67c 100644 --- a/src/tests/test_title_generation_service.py +++ b/src/tests/test_title_generation_service.py @@ -178,3 +178,67 @@ def test_returns_existing_instance(self, mock_existing): mock_existing.__bool__ = lambda self: True result = get_title_service() assert result is mock_existing + + +# --------------------------------------------------------------------------- +# TitleService.initialize wiring +# --------------------------------------------------------------------------- + + +class TestTitleServiceInitialize: + """Tests that initialize wires the chat client with correct config.""" + + @patch("services.title_service.app_settings") + @patch("services.title_service.DefaultAzureCredential") + @patch("services.title_service.OpenAIChatCompletionClient") + @patch("services.title_service.Agent") + def test_initialize_wires_credential_direct_mode( + self, mock_agent, mock_client, mock_cred_cls, mock_settings + ): + """Test that initialize passes credential directly to chat client.""" + mock_credential = MagicMock() + mock_cred_cls.return_value = mock_credential + + mock_settings.ai_foundry.use_foundry = False + mock_settings.azure_openai.endpoint = "https://test.openai.azure.com" + mock_settings.azure_openai.gpt_model = "gpt-4o" + mock_settings.azure_openai.api_version = "2024-02-15" + + svc = TitleService() + svc.initialize() + + mock_client.assert_called_once_with( + azure_endpoint="https://test.openai.azure.com", + model="gpt-4o", + api_version="2024-02-15", + credential=mock_credential, + ) + assert svc._initialized is True + + @patch("services.title_service.app_settings") + @patch("services.title_service.DefaultAzureCredential") + @patch("services.title_service.OpenAIChatCompletionClient") + @patch("services.title_service.Agent") + def test_initialize_wires_credential_foundry_mode( + self, mock_agent, mock_client, mock_cred_cls, mock_settings + ): + """Test that initialize uses Foundry endpoint and model.""" + mock_credential = MagicMock() + mock_cred_cls.return_value = mock_credential + + mock_settings.ai_foundry.use_foundry = True + mock_settings.azure_openai.endpoint = "https://foundry.openai.azure.com" + mock_settings.ai_foundry.model_deployment = "gpt-4o-foundry" + mock_settings.azure_openai.gpt_model = "gpt-4o" + mock_settings.azure_openai.api_version = "2024-02-15" + + svc = TitleService() + svc.initialize() + + mock_client.assert_called_once_with( + azure_endpoint="https://foundry.openai.azure.com", + model="gpt-4o-foundry", + api_version="2024-02-15", + credential=mock_credential, + ) + assert svc._initialized is True