From 74234873ff086c90c269ceb81577066f78a05142 Mon Sep 17 00:00:00 2001 From: Roopan-Microsoft <168007406+Roopan-Microsoft@users.noreply.github.com> Date: Wed, 10 Jun 2026 17:46:13 +0530 Subject: [PATCH] Revert "fix: enhance message handling and context management in orchestrators" --- .../src/libs/agent_framework/agent_builder.py | 139 +++++------------- .../src/libs/agent_framework/agent_info.py | 10 +- .../azure_openai_response_retry.py | 133 +---------------- .../agent_framework/groupchat_orchestrator.py | 15 +- .../shared_memory_context_provider.py | 57 +++---- .../src/libs/base/orchestrator_base.py | 26 ++-- .../src/steps/migration_processor.py | 36 +++-- .../agent_framework/test_agent_builder.py | 22 +-- .../test_groupchat_orchestrator_internals.py | 8 +- .../test_shared_memory_context_provider.py | 119 +++++++-------- .../steps/test_migration_processor_run.py | 4 +- 11 files changed, 184 insertions(+), 385 deletions(-) diff --git a/src/processor/src/libs/agent_framework/agent_builder.py b/src/processor/src/libs/agent_framework/agent_builder.py index c9b747d0..6a7e4409 100644 --- a/src/processor/src/libs/agent_framework/agent_builder.py +++ b/src/processor/src/libs/agent_framework/agent_builder.py @@ -11,7 +11,6 @@ AgentMiddleware, BaseChatClient, ChatMiddleware, - ChatOptions, ContextProvider, FunctionTool, ToolMode, @@ -442,61 +441,32 @@ def build(self) -> Agent: async with agent: response = await agent.run("Hello!") """ - # Build default_options from model parameters - options_dict: dict[str, Any] = {} - if self._frequency_penalty is not None: - options_dict["frequency_penalty"] = self._frequency_penalty - if self._logit_bias is not None: - options_dict["logit_bias"] = self._logit_bias - if self._max_tokens is not None: - options_dict["max_tokens"] = self._max_tokens - if self._metadata is not None: - options_dict["metadata"] = self._metadata - if self._model_id is not None: - options_dict["model"] = self._model_id - if self._presence_penalty is not None: - options_dict["presence_penalty"] = self._presence_penalty - if self._response_format is not None: - options_dict["response_format"] = self._response_format - if self._seed is not None: - options_dict["seed"] = self._seed - if self._stop is not None: - options_dict["stop"] = self._stop - if self._store is not None: - options_dict["store"] = self._store - if self._temperature is not None: - options_dict["temperature"] = self._temperature - if self._tool_choice is not None: - options_dict["tool_choice"] = self._tool_choice - if self._top_p is not None: - options_dict["top_p"] = self._top_p - if self._user is not None: - options_dict["user"] = self._user - if self._additional_chat_options: - options_dict.update(self._additional_chat_options) - - default_options = ChatOptions(**options_dict) if options_dict else None - - # Agent expects context_providers as a Sequence; wrap single instance in a list - ctx_providers = self._context_providers - if ctx_providers is not None and not isinstance(ctx_providers, list): - ctx_providers = [ctx_providers] - - # Agent expects middleware as a Sequence; wrap single instance in a list - mw = self._middleware - if mw is not None and not isinstance(mw, list): - mw = [mw] - return Agent( - self._chat_client, + chat_client=self._chat_client, instructions=self._instructions, id=self._id, name=self._name, description=self._description, + chat_message_store_factory=self._chat_message_store_factory, + conversation_id=self._conversation_id, + context_providers=self._context_providers, + middleware=self._middleware, + frequency_penalty=self._frequency_penalty, + logit_bias=self._logit_bias, + max_tokens=self._max_tokens, + metadata=self._metadata, + model_id=self._model_id, + presence_penalty=self._presence_penalty, + response_format=self._response_format, + seed=self._seed, + stop=self._stop, + store=self._store, + temperature=self._temperature, + tool_choice=self._tool_choice, tools=self._tools, - default_options=default_options, - context_providers=ctx_providers, - middleware=mw, + top_p=self._top_p, + user=self._user, + additional_chat_options=self._additional_chat_options, **self._kwargs, ) @@ -785,60 +755,31 @@ def create_agent( ``async with`` to ensure proper initialization and cleanup via the Agent's async context manager protocol. """ - # Build default_options from model parameters - opts: dict[str, Any] = {} - if frequency_penalty is not None: - opts["frequency_penalty"] = frequency_penalty - if logit_bias is not None: - opts["logit_bias"] = logit_bias - if max_tokens is not None: - opts["max_tokens"] = max_tokens - if metadata is not None: - opts["metadata"] = metadata - if model_id is not None: - opts["model"] = model_id - if presence_penalty is not None: - opts["presence_penalty"] = presence_penalty - if response_format is not None: - opts["response_format"] = response_format - if seed is not None: - opts["seed"] = seed - if stop is not None: - opts["stop"] = stop - if store is not None: - opts["store"] = store - if temperature is not None: - opts["temperature"] = temperature - if tool_choice is not None: - opts["tool_choice"] = tool_choice - if top_p is not None: - opts["top_p"] = top_p - if user is not None: - opts["user"] = user - if additional_chat_options: - opts.update(additional_chat_options) - - default_options = ChatOptions(**opts) if opts else None - - # Agent expects context_providers as a Sequence; wrap single instance in a list - ctx_providers = context_providers - if ctx_providers is not None and not isinstance(ctx_providers, list): - ctx_providers = [ctx_providers] - - # Agent expects middleware as a Sequence; wrap single instance in a list - mw = middleware - if mw is not None and not isinstance(mw, list): - mw = [mw] - return Agent( - chat_client, + chat_client=chat_client, instructions=instructions, id=id, name=name, description=description, + chat_message_store_factory=chat_message_store_factory, + conversation_id=conversation_id, + context_providers=context_providers, + middleware=middleware, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + metadata=metadata, + model_id=model_id, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + store=store, + temperature=temperature, + tool_choice=tool_choice, tools=tools, - default_options=default_options, - context_providers=ctx_providers, - middleware=mw, + top_p=top_p, + user=user, + additional_chat_options=additional_chat_options, **kwargs, ) diff --git a/src/processor/src/libs/agent_framework/agent_info.py b/src/processor/src/libs/agent_framework/agent_info.py index 1ae3def7..82f657b6 100644 --- a/src/processor/src/libs/agent_framework/agent_info.py +++ b/src/processor/src/libs/agent_framework/agent_info.py @@ -5,15 +5,13 @@ from typing import Any, Callable, MutableMapping, Sequence -from agent_framework import FunctionTool, MCPStdioTool, MCPStreamableHTTPTool +from agent_framework import FunctionTool from jinja2 import Template from openai import BaseModel from pydantic import Field from .agent_framework_helper import AgentFrameworkHelper, ClientType -ToolType = FunctionTool | MCPStreamableHTTPTool | MCPStdioTool | Callable[..., Any] | MutableMapping[str, Any] - class AgentInfo(BaseModel): agent_name: str @@ -23,8 +21,10 @@ class AgentInfo(BaseModel): agent_instruction: str | None = Field(default=None) agent_framework_helper: AgentFrameworkHelper | None = Field(default=None) tools: ( - ToolType - | Sequence[ToolType] + FunctionTool + | Callable[..., Any] + | MutableMapping[str, Any] + | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]] | None ) = Field(default=None) diff --git a/src/processor/src/libs/agent_framework/azure_openai_response_retry.py b/src/processor/src/libs/agent_framework/azure_openai_response_retry.py index c691ae7b..93695000 100644 --- a/src/processor/src/libs/agent_framework/azure_openai_response_retry.py +++ b/src/processor/src/libs/agent_framework/azure_openai_response_retry.py @@ -325,117 +325,6 @@ def _bool(name: str, default: bool) -> bool: ) -def _get_content_items(message: Any) -> list[Any]: - """Return the list of content items from a message, or empty list.""" - contents = None - if isinstance(message, dict): - contents = message.get("contents") or message.get("content") - else: - contents = getattr(message, "contents", None) or getattr(message, "content", None) - if isinstance(contents, list): - return contents - return [] - - -def _remove_orphan_tool_messages(messages: list[Any]) -> list[Any]: - """Remove messages with orphaned function_call or function_result items. - - The Responses API requires every function_call in the input to have a - corresponding function_call_output (function_result). If context trimming - breaks these pairs, the API rejects the request. - """ - # Collect call_ids for function_calls and function_results - call_ids_with_call: set[str] = set() - call_ids_with_result: set[str] = set() - - for m in messages: - for item in _get_content_items(m): - item_type = None - call_id = None - if isinstance(item, dict): - item_type = item.get("type") - call_id = item.get("call_id") - else: - item_type = getattr(item, "type", None) - call_id = getattr(item, "call_id", None) - if not call_id: - continue - if item_type == "function_call": - call_ids_with_call.add(call_id) - elif item_type == "function_result": - call_ids_with_result.add(call_id) - - # Identify orphaned call_ids - orphaned_calls = call_ids_with_call - call_ids_with_result - orphaned_results = call_ids_with_result - call_ids_with_call - - if not orphaned_calls and not orphaned_results: - return messages - - logger.warning( - "[AOAI_CTX_TRIM] removing orphaned tool messages: %d orphaned calls, %d orphaned results", - len(orphaned_calls), - len(orphaned_results), - ) - - # Remove messages that ONLY contain orphaned tool items - cleaned: list[Any] = [] - for m in messages: - items = _get_content_items(m) - if not items: - cleaned.append(m) - continue - - has_orphan = False - has_non_orphan = False - for item in items: - item_type = None - call_id = None - if isinstance(item, dict): - item_type = item.get("type") - call_id = item.get("call_id") - else: - item_type = getattr(item, "type", None) - call_id = getattr(item, "call_id", None) - if call_id and item_type == "function_call" and call_id in orphaned_calls: - has_orphan = True - elif call_id and item_type == "function_result" and call_id in orphaned_results: - has_orphan = True - else: - has_non_orphan = True - - if has_orphan and not has_non_orphan: - # Message contains ONLY orphaned tool items — drop it entirely - continue - elif has_orphan and has_non_orphan: - # Message has both orphan and non-orphan content. - # Drop orphaned items if possible, keeping the rest. - if isinstance(items, list) and not isinstance(m, dict): - # Filter out orphaned content items from the message - filtered = [] - for item in items: - item_type = getattr(item, "type", None) - call_id = getattr(item, "call_id", None) - if call_id and item_type == "function_call" and call_id in orphaned_calls: - continue - if call_id and item_type == "function_result" and call_id in orphaned_results: - continue - filtered.append(item) - if filtered: - try: - m.contents = filtered - except Exception: - pass - cleaned.append(m) - # else: drop message entirely if no content remains - else: - cleaned.append(m) - else: - cleaned.append(m) - - return cleaned - - def _trim_messages( messages: MutableSequence[Any], *, cfg: ContextTrimConfig ) -> list[Any]: @@ -525,11 +414,6 @@ def _total_chars(msgs: list[Any]) -> int: break combined.pop(drop_index) - # Phase final: Remove orphaned tool call / tool result messages. - # The Responses API requires every function_call to have a matching - # function_call_output. Trimming may break these pairs. - combined = _remove_orphan_tool_messages(combined) - return combined @@ -655,27 +539,12 @@ def __init__( # Map legacy params to OpenAIChatClient params if deployment_name and "model" not in kwargs: kwargs["model"] = deployment_name - if endpoint and not kwargs.get("azure_endpoint"): + if endpoint and "azure_endpoint" not in kwargs: kwargs["azure_endpoint"] = endpoint if ad_token_provider and kwargs.get("credential") is None: kwargs["credential"] = ad_token_provider - # Remove None-valued keys that would conflict with env-based settings - for k in list(kwargs): - if kwargs[k] is None: - del kwargs[k] - super().__init__(*args, **kwargs) - - # OpenAIChatClient appends /v1/ to azure_endpoint but Azure AI Foundry - # endpoints expect /openai/responses (without /v1/). Fix the base URL. - if hasattr(self, "client") and self.client is not None: - base = str(self.client.base_url) - if "/openai/v1/" in base: - import httpx - corrected = base.replace("/openai/v1/", "/openai/") - self.client._base_url = httpx.URL(corrected) - self._retry_config = retry_config or RateLimitRetryConfig.from_env() self._context_trim_config = ContextTrimConfig.from_env() diff --git a/src/processor/src/libs/agent_framework/groupchat_orchestrator.py b/src/processor/src/libs/agent_framework/groupchat_orchestrator.py index cc38a086..ebebeceb 100644 --- a/src/processor/src/libs/agent_framework/groupchat_orchestrator.py +++ b/src/processor/src/libs/agent_framework/groupchat_orchestrator.py @@ -28,9 +28,9 @@ Role, SupportsAgentRun, Workflow, + WorkflowBuilder as GroupChatBuilder, WorkflowEvent, ) -from agent_framework_orchestrations import GroupChatBuilder from mem0 import AsyncMemory from pydantic import BaseModel, ValidationError @@ -491,7 +491,7 @@ async def run_stream( # Execute with streaming conversation: list[Message] = [] - async for event in group_chat_workflow.run(task_prompt, stream=True): + async for event in group_chat_workflow.run_stream(task_prompt): # Enforce wall-clock timeout if configured. if self.max_seconds is not None: elapsed = (datetime.now() - start_time).total_seconds() @@ -1114,10 +1114,9 @@ async def _build_groupchat(self) -> Workflow: ] return ( - GroupChatBuilder( - participants=participants, - orchestrator_agent=coordinator, - ) + GroupChatBuilder() + .set_manager(coordinator) + .participants(participants) .build() ) @@ -1142,7 +1141,7 @@ async def _generate_final_result( result = await result_generator.run( final_conversation, - options={"response_format": result_format}, + response_format=result_format, ) text = result.messages[-1].text @@ -1175,7 +1174,7 @@ async def _generate_final_result( ) retry_result = await result_generator.run( retry_conversation, - options={"response_format": result_format}, + response_format=result_format, ) retry_text = retry_result.messages[-1].text retry_json_payload = self._extract_first_json_payload(retry_text) diff --git a/src/processor/src/libs/agent_framework/shared_memory_context_provider.py b/src/processor/src/libs/agent_framework/shared_memory_context_provider.py index df78ebe0..fd95a5de 100644 --- a/src/processor/src/libs/agent_framework/shared_memory_context_provider.py +++ b/src/processor/src/libs/agent_framework/shared_memory_context_provider.py @@ -77,7 +77,6 @@ def __init__( top_k: Number of relevant memories to retrieve per turn. score_threshold: Minimum similarity score for memory retrieval. """ - super().__init__(source_id=f"shared_memory_{agent_name}_{step}") self._memory_store = memory_store self._agent_name = agent_name self._step = step @@ -97,14 +96,11 @@ def __init__( break self._prior_steps = _STEP_ORDER[:step_idx] if step_idx else [] - async def before_run( + async def invoking( self, - *, - agent, - session, - context, - state, - ) -> None: + messages: Message | MutableSequence[Message], + **kwargs, + ) -> Context: """Called before the agent's LLM call. Injects relevant shared memories. Only searches memories from PREVIOUS steps. Within the current step, @@ -112,13 +108,12 @@ async def before_run( """ # Skip if this is the first step (no prior memories exist) if not self._prior_steps: - return + return Context() - # Extract query from the most recent messages in context - messages = context.get_messages() + # Extract query from the most recent messages query = self._extract_query(messages) if not query: - return + return Context() try: memories = await self._memory_store.search( @@ -132,15 +127,15 @@ async def before_run( self._agent_name, e, ) - return + return Context() if not memories: - return + return Context() # Format memories into context instructions formatted = self._format_memories(memories) if not formatted: - return + return Context() instructions = f"{self.DEFAULT_CONTEXT_PROMPT}\n\n{formatted}" @@ -152,15 +147,14 @@ async def before_run( len(instructions), ) - context.extend_instructions(self.source_id, instructions) + return Context(instructions=instructions) - async def after_run( + async def invoked( self, - *, - agent, - session, - context, - state, + request_messages: Message | Sequence[Message], + response_messages: Message | Sequence[Message] | None = None, + invoke_exception: Exception | None = None, + **kwargs, ) -> None: """Called after the agent's LLM response. Buffers the response for storage. @@ -169,26 +163,33 @@ async def after_run( This means only the agent's last response per step gets stored, which is the most complete and useful summary. """ - response = context.response - if response is None: + if invoke_exception is not None: + logger.debug( + "[MEMORY] invoked() skipped for %s — exception: %s", + self._agent_name, + invoke_exception, + ) + return + + if response_messages is None: logger.debug( - "[MEMORY] after_run() skipped for %s — no response", + "[MEMORY] invoked() skipped for %s — no response_messages", self._agent_name, ) return # Extract text from response - content = response.text if hasattr(response, "text") else None + content = self._extract_text(response_messages) if not content or len(content) < MIN_CONTENT_LENGTH_TO_STORE: logger.debug( - "[MEMORY] after_run() skipped for %s — content too short (%d chars)", + "[MEMORY] invoked() skipped for %s — content too short (%d chars)", self._agent_name, len(content) if content else 0, ) return logger.info( - "[MEMORY] after_run() buffering for %s (step=%s, %d chars)", + "[MEMORY] invoked() buffering for %s (step=%s, %d chars)", self._agent_name, self._step, len(content), diff --git a/src/processor/src/libs/base/orchestrator_base.py b/src/processor/src/libs/base/orchestrator_base.py index 5c53a961..420664a7 100644 --- a/src/processor/src/libs/base/orchestrator_base.py +++ b/src/processor/src/libs/base/orchestrator_base.py @@ -9,7 +9,7 @@ from abc import abstractmethod from typing import Any, Callable, Generic, MutableMapping, Sequence, TypeVar -from agent_framework import Agent, FunctionTool, InMemoryHistoryProvider, ToolResultCompactionStrategy +from agent_framework import Agent, FunctionTool, ToolResultCompactionStrategy from libs.agent_framework.agent_builder import AgentBuilder from libs.agent_framework.agent_framework_helper import ClientType @@ -169,7 +169,6 @@ async def create_agents( AgentBuilder(agent_client) .with_name(agent_info.agent_name) .with_instructions(instruction) - .with_store(False) ) # Only attach tools when provided. (Coordinator should typically have none.) @@ -207,19 +206,18 @@ async def create_agents( .with_tool_choice("none") ) - # Attach context providers to expert agents + # Attach shared memory context provider to expert agents # (not Coordinator, not ResultGenerator — they don't need memory) - if agent_info.agent_name not in ("Coordinator", "ResultGenerator"): - providers: list = [InMemoryHistoryProvider()] - if self.memory_store is not None: - providers.append( - SharedMemoryContextProvider( - memory_store=self.memory_store, - agent_name=agent_info.agent_name, - step=self.step_name, - ) - ) - builder = builder.with_context_providers(providers) + if ( + self.memory_store is not None + and agent_info.agent_name not in ("Coordinator", "ResultGenerator") + ): + memory_provider = SharedMemoryContextProvider( + memory_store=self.memory_store, + agent_name=agent_info.agent_name, + step=self.step_name, + ) + builder = builder.with_context_providers(memory_provider) agent = builder.build() agents[agent_info.agent_name] = agent diff --git a/src/processor/src/steps/migration_processor.py b/src/processor/src/steps/migration_processor.py index 7c7328fa..c73f570a 100644 --- a/src/processor/src/steps/migration_processor.py +++ b/src/processor/src/steps/migration_processor.py @@ -157,18 +157,30 @@ def _init_workflow(self) -> Workflow: Workflow The built workflow ready to execute. """ - analysis = AnalysisExecutor(id="analysis", app_context=self.app_context) - design = DesignExecutor(id="design", app_context=self.app_context) - yaml_convert = YamlConvertExecutor(id="yaml", app_context=self.app_context) - documentation = DocumentationExecutor( - id="documentation", app_context=self.app_context - ) - workflow = ( - WorkflowBuilder(start_executor=analysis) - .add_edge(analysis, design) - .add_edge(design, yaml_convert) - .add_edge(yaml_convert, documentation) + WorkflowBuilder() + .register_executor( + lambda: AnalysisExecutor(id="analysis", app_context=self.app_context), + name="analysis", + ) + .register_executor( + lambda: DesignExecutor(id="design", app_context=self.app_context), + name="design", + ) + .register_executor( + lambda: YamlConvertExecutor(id="yaml", app_context=self.app_context), + name="yaml", + ) + .register_executor( + lambda: DocumentationExecutor( + id="documentation", app_context=self.app_context + ), + name="documentation", + ) + .set_start_executor("analysis") + .add_edge("analysis", "design") + .add_edge("design", "yaml") + .add_edge("yaml", "documentation") .build() ) @@ -346,7 +358,7 @@ async def _generate_report_summary( "top_remediations": remediation_titles, } - async for event in self.workflow.run(input_data, stream=True): + async for event in self.workflow.run_stream(input_data): if event.type == "started": logger.info("Workflow started (%s)", event.origin.value) diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py b/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py index f6e99c36..cbfede63 100644 --- a/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py +++ b/src/processor/src/tests/unit/libs/agent_framework/test_agent_builder.py @@ -157,17 +157,15 @@ def test_build_passes_all_state_to_chat_agent(self): .build() ) assert agent is mock_chat.return_value - args = mock_chat.call_args.args kwargs = mock_chat.call_args.kwargs - assert args[0] is chat_client + assert kwargs["chat_client"] is chat_client assert kwargs["instructions"] == "inst" assert kwargs["id"] == "id1" assert kwargs["name"] == "name1" assert kwargs["description"] == "desc1" - opts = kwargs["default_options"] - assert opts["temperature"] == 0.3 - assert opts["max_tokens"] == 100 - assert opts["tool_choice"] == "auto" + assert kwargs["temperature"] == 0.3 + assert kwargs["max_tokens"] == 100 + assert kwargs["tool_choice"] == "auto" assert kwargs["extra"] == 42 @@ -182,13 +180,11 @@ def test_create_agent_invokes_chat_agent(self): temperature=0.4, ) assert agent is mock_chat.return_value - args = mock_chat.call_args.args kwargs = mock_chat.call_args.kwargs - assert args[0] is chat_client + assert kwargs["chat_client"] is chat_client assert kwargs["instructions"] == "i" assert kwargs["name"] == "n" - opts = kwargs["default_options"] - assert opts["temperature"] == 0.4 + assert kwargs["temperature"] == 0.4 def test_create_agent_by_agentinfo_uses_helper_and_creates_client(self): # Build a fake AgentInfo with the minimum surface used by the method @@ -219,14 +215,12 @@ def test_create_agent_by_agentinfo_uses_helper_and_creates_client(self): assert agent is mock_chat.return_value helper.settings.get_service_config.assert_called_once_with("default") helper.create_client.assert_called_once() - args = mock_chat.call_args.args ck = mock_chat.call_args.kwargs - assert args[0] == "client-instance" + assert ck["chat_client"] == "client-instance" assert ck["instructions"] == "instr" assert ck["name"] == "A" assert ck["description"] == "D" - opts = ck["default_options"] - assert opts["temperature"] == 0.2 + assert ck["temperature"] == 0.2 def test_create_agent_by_agentinfo_falls_back_to_system_prompt(self): helper = MagicMock() diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py b/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py index 1b2bb182..5ccbce22 100644 --- a/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py +++ b/src/processor/src/tests/unit/libs/agent_framework/test_groupchat_orchestrator_internals.py @@ -743,14 +743,16 @@ def test_build_groupchat_invokes_builder(self): }) with patch("libs.agent_framework.groupchat_orchestrator.GroupChatBuilder") as MockBuilder: built = MagicMock() + built.set_manager.return_value = built + built.participants.return_value = built built.build.return_value = "wf" MockBuilder.return_value = built wf = _run(orch._build_groupchat()) assert wf == "wf" # ResultGenerator excluded from participants - kwargs = MockBuilder.call_args.kwargs - assert "arch" in kwargs["participants"] - assert "rg" not in kwargs["participants"] + kwargs = built.participants.call_args.args[0] + assert "arch" in kwargs + assert "rg" not in kwargs # ----------------------------------------------------------------------------- diff --git a/src/processor/src/tests/unit/libs/agent_framework/test_shared_memory_context_provider.py b/src/processor/src/tests/unit/libs/agent_framework/test_shared_memory_context_provider.py index 3398c764..ab2bc8b2 100644 --- a/src/processor/src/tests/unit/libs/agent_framework/test_shared_memory_context_provider.py +++ b/src/processor/src/tests/unit/libs/agent_framework/test_shared_memory_context_provider.py @@ -61,36 +61,8 @@ def _make_provider(store=None): ), store -def _make_context(messages=None, response_text=None): - """Create a mock SessionContext for before_run/after_run calls.""" - ctx = MagicMock() - ctx.get_messages = MagicMock(return_value=messages or []) - ctx.extend_instructions = MagicMock() - if response_text is not None: - ctx.response = MagicMock() - ctx.response.text = response_text - else: - ctx.response = None - return ctx - - -async def _call_before_run(provider, messages): - """Helper to call before_run and return the instructions that were injected.""" - ctx = _make_context(messages=messages) - await provider.before_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={}) - if ctx.extend_instructions.called: - return ctx.extend_instructions.call_args[0][1] # second positional arg = instructions - return None - - -async def _call_after_run(provider, response_text): - """Helper to call after_run with a response.""" - ctx = _make_context(response_text=response_text) - await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={}) - - # --------------------------------------------------------------------------- -# before_run() — Pre-LLM memory injection +# invoking() — Pre-LLM memory injection # --------------------------------------------------------------------------- @@ -103,11 +75,11 @@ async def _run(): ] messages = [_make_chat_message("How should we handle storage configuration?")] - instructions = await _call_before_run(provider, messages) + context = await provider.invoking(messages) - assert instructions is not None - assert "GKE Filestore CSI" in instructions - assert "Azure Files for AKS" in instructions + assert context.instructions is not None + assert "GKE Filestore CSI" in context.instructions + assert "Azure Files for AKS" in context.instructions store.search.assert_called_once() asyncio.run(_run()) @@ -116,8 +88,9 @@ async def _run(): def test_invoking_empty_messages_returns_empty(): async def _run(): provider, _ = _make_provider() - instructions = await _call_before_run(provider, []) - assert instructions is None + context = await provider.invoking([]) + assert context.instructions is None + assert getattr(context, "messages", []) == [] asyncio.run(_run()) @@ -128,8 +101,8 @@ async def _run(): store.search.return_value = [] messages = [_make_chat_message("What is the overall migration plan for AKS?")] - instructions = await _call_before_run(provider, messages) - assert instructions is None + context = await provider.invoking(messages) + assert context.instructions is None asyncio.run(_run()) @@ -140,8 +113,8 @@ async def _run(): store.search.side_effect = Exception("search failed") messages = [_make_chat_message("What is the networking plan for AKS?")] - instructions = await _call_before_run(provider, messages) - assert instructions is None + context = await provider.invoking(messages) + assert context.instructions is None asyncio.run(_run()) @@ -152,7 +125,7 @@ async def _run(): long_text = "x" * 5000 messages = [_make_chat_message(long_text)] - await _call_before_run(provider, messages) + await provider.invoking(messages) query = store.search.call_args.kwargs["query"] assert len(query) <= 2000 @@ -169,7 +142,7 @@ async def _run(): _make_chat_message("Latest question about storage"), ] - await _call_before_run(provider, messages) + await provider.invoking(messages) query = store.search.call_args.kwargs["query"] assert "Latest question about storage" in query @@ -186,10 +159,10 @@ async def _run(): store.search.return_value = large_memories messages = [_make_chat_message("What storage configuration should we use for persistent volumes?")] - instructions = await _call_before_run(provider, messages) + context = await provider.invoking(messages) - assert instructions is not None - assert len(instructions) <= MAX_MEMORY_CONTEXT_CHARS + 200 + assert context.instructions is not None + assert len(context.instructions) <= MAX_MEMORY_CONTEXT_CHARS + 200 asyncio.run(_run()) @@ -202,10 +175,10 @@ async def _run(): ] messages = [_make_chat_message("What storage class should we choose for the cluster?")] - instructions = await _call_before_run(provider, messages) + context = await provider.invoking(messages) - assert "Chief Architect" in instructions - assert "design" in instructions + assert "Chief Architect" in context.instructions + assert "design" in context.instructions asyncio.run(_run()) @@ -216,25 +189,26 @@ async def _run(): store.search.return_value = [_make_memory_entry("some memory")] single = _make_chat_message("What about networking configuration for AKS?") - instructions = await _call_before_run(provider, [single]) + context = await provider.invoking(single) - assert instructions is not None + assert context.instructions is not None store.search.assert_called_once() asyncio.run(_run()) # --------------------------------------------------------------------------- -# after_run() — Post-LLM memory storage +# invoked() — Post-LLM memory storage # --------------------------------------------------------------------------- def test_invoked_stores_response(): async def _run(): provider, store = _make_provider() - response_text = "We should use Azure CNI for networking configuration in the AKS cluster" + request = [_make_chat_message("What is the networking plan for AKS?")] + response = [_make_chat_message("We should use Azure CNI for networking configuration in the AKS cluster")] - await _call_after_run(provider, response_text) + await provider.invoked(request, response) await provider.flush() store.add.assert_called_once() @@ -248,9 +222,10 @@ async def _run(): def test_invoked_skips_on_exception(): async def _run(): provider, store = _make_provider() - # after_run with no response simulates exception path - ctx = _make_context(response_text=None) - await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={}) + request = [_make_chat_message("Q")] + response = [_make_chat_message("A" * 100)] + + await provider.invoked(request, response, invoke_exception=Exception("fail")) store.add.assert_not_called() asyncio.run(_run()) @@ -259,8 +234,9 @@ async def _run(): def test_invoked_skips_none_response(): async def _run(): provider, store = _make_provider() - ctx = _make_context(response_text=None) - await provider.after_run(agent=MagicMock(), session=MagicMock(), context=ctx, state={}) + request = [_make_chat_message("Q")] + + await provider.invoked(request, None) store.add.assert_not_called() asyncio.run(_run()) @@ -269,8 +245,10 @@ async def _run(): def test_invoked_skips_short_response(): async def _run(): provider, store = _make_provider() - short_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE - 1) - await _call_after_run(provider, short_text) + request = [_make_chat_message("Q")] + short = [_make_chat_message("x" * (MIN_CONTENT_LENGTH_TO_STORE - 1))] + + await provider.invoked(request, short) store.add.assert_not_called() asyncio.run(_run()) @@ -279,8 +257,10 @@ async def _run(): def test_invoked_stores_long_response(): async def _run(): provider, store = _make_provider() - long_text = "x" * (MIN_CONTENT_LENGTH_TO_STORE + 1) - await _call_after_run(provider, long_text) + request = [_make_chat_message("Q")] + long_resp = [_make_chat_message("x" * (MIN_CONTENT_LENGTH_TO_STORE + 1))] + + await provider.invoked(request, long_resp) await provider.flush() store.add.assert_called_once() @@ -290,10 +270,11 @@ async def _run(): def test_invoked_increments_turn_counter(): async def _run(): provider, store = _make_provider() - response_text = "A" * 100 + request = [_make_chat_message("Q")] + response = [_make_chat_message("A" * 100)] - await _call_after_run(provider, response_text) - await _call_after_run(provider, response_text) + await provider.invoked(request, response) + await provider.invoked(request, response) assert provider._turn_counter == 2 asyncio.run(_run()) @@ -303,9 +284,10 @@ def test_invoked_store_failure_does_not_raise(): async def _run(): provider, store = _make_provider() store.add.side_effect = Exception("store failed") - response_text = "A" * 100 + request = [_make_chat_message("Q")] + response = [_make_chat_message("A" * 100)] - await _call_after_run(provider, response_text) + await provider.invoked(request, response) await provider.flush() # Should not raise asyncio.run(_run()) @@ -314,9 +296,10 @@ async def _run(): def test_invoked_with_single_message(): async def _run(): provider, store = _make_provider() - response_text = "We should use Azure CNI Overlay for the networking configuration in AKS" + request = _make_chat_message("What is the question about networking?") + response = _make_chat_message("We should use Azure CNI Overlay for the networking configuration in AKS") - await _call_after_run(provider, response_text) + await provider.invoked(request, response) await provider.flush() store.add.assert_called_once() diff --git a/src/processor/src/tests/unit/steps/test_migration_processor_run.py b/src/processor/src/tests/unit/steps/test_migration_processor_run.py index b73abc8b..683fcc5d 100644 --- a/src/processor/src/tests/unit/steps/test_migration_processor_run.py +++ b/src/processor/src/tests/unit/steps/test_migration_processor_run.py @@ -54,12 +54,12 @@ def _make_processor(events: list, memory_store=None) -> MigrationProcessor: proc._telemetry = telemetry # expose for assertions - async def _stream(_input, **kwargs): + async def _stream(_input): for ev in events: yield ev workflow = MagicMock() - workflow.run = _stream + workflow.run_stream = _stream proc.workflow = workflow # Patch _create_memory_store as an AsyncMock returning the provided value.