Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 40 additions & 99 deletions src/processor/src/libs/agent_framework/agent_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
AgentMiddleware,
BaseChatClient,
ChatMiddleware,
ChatOptions,
ContextProvider,
FunctionTool,
ToolMode,
Expand Down Expand Up @@ -442,61 +441,32 @@ def build(self) -> Agent:
async with agent:
response = await agent.run("Hello!")
"""
# Build default_options from model parameters
options_dict: dict[str, Any] = {}
if self._frequency_penalty is not None:
options_dict["frequency_penalty"] = self._frequency_penalty
if self._logit_bias is not None:
options_dict["logit_bias"] = self._logit_bias
if self._max_tokens is not None:
options_dict["max_tokens"] = self._max_tokens
if self._metadata is not None:
options_dict["metadata"] = self._metadata
if self._model_id is not None:
options_dict["model"] = self._model_id
if self._presence_penalty is not None:
options_dict["presence_penalty"] = self._presence_penalty
if self._response_format is not None:
options_dict["response_format"] = self._response_format
if self._seed is not None:
options_dict["seed"] = self._seed
if self._stop is not None:
options_dict["stop"] = self._stop
if self._store is not None:
options_dict["store"] = self._store
if self._temperature is not None:
options_dict["temperature"] = self._temperature
if self._tool_choice is not None:
options_dict["tool_choice"] = self._tool_choice
if self._top_p is not None:
options_dict["top_p"] = self._top_p
if self._user is not None:
options_dict["user"] = self._user
if self._additional_chat_options:
options_dict.update(self._additional_chat_options)

default_options = ChatOptions(**options_dict) if options_dict else None

# Agent expects context_providers as a Sequence; wrap single instance in a list
ctx_providers = self._context_providers
if ctx_providers is not None and not isinstance(ctx_providers, list):
ctx_providers = [ctx_providers]

# Agent expects middleware as a Sequence; wrap single instance in a list
mw = self._middleware
if mw is not None and not isinstance(mw, list):
mw = [mw]

return Agent(
self._chat_client,
chat_client=self._chat_client,
instructions=self._instructions,
id=self._id,
name=self._name,
description=self._description,
chat_message_store_factory=self._chat_message_store_factory,
conversation_id=self._conversation_id,
context_providers=self._context_providers,
middleware=self._middleware,
frequency_penalty=self._frequency_penalty,
logit_bias=self._logit_bias,
max_tokens=self._max_tokens,
metadata=self._metadata,
model_id=self._model_id,
presence_penalty=self._presence_penalty,
response_format=self._response_format,
seed=self._seed,
stop=self._stop,
store=self._store,
temperature=self._temperature,
tool_choice=self._tool_choice,
tools=self._tools,
default_options=default_options,
context_providers=ctx_providers,
middleware=mw,
top_p=self._top_p,
user=self._user,
additional_chat_options=self._additional_chat_options,
**self._kwargs,
)

Expand Down Expand Up @@ -785,60 +755,31 @@ def create_agent(
``async with`` to ensure proper initialization and cleanup via the Agent's
async context manager protocol.
"""
# Build default_options from model parameters
opts: dict[str, Any] = {}
if frequency_penalty is not None:
opts["frequency_penalty"] = frequency_penalty
if logit_bias is not None:
opts["logit_bias"] = logit_bias
if max_tokens is not None:
opts["max_tokens"] = max_tokens
if metadata is not None:
opts["metadata"] = metadata
if model_id is not None:
opts["model"] = model_id
if presence_penalty is not None:
opts["presence_penalty"] = presence_penalty
if response_format is not None:
opts["response_format"] = response_format
if seed is not None:
opts["seed"] = seed
if stop is not None:
opts["stop"] = stop
if store is not None:
opts["store"] = store
if temperature is not None:
opts["temperature"] = temperature
if tool_choice is not None:
opts["tool_choice"] = tool_choice
if top_p is not None:
opts["top_p"] = top_p
if user is not None:
opts["user"] = user
if additional_chat_options:
opts.update(additional_chat_options)

default_options = ChatOptions(**opts) if opts else None

# Agent expects context_providers as a Sequence; wrap single instance in a list
ctx_providers = context_providers
if ctx_providers is not None and not isinstance(ctx_providers, list):
ctx_providers = [ctx_providers]

# Agent expects middleware as a Sequence; wrap single instance in a list
mw = middleware
if mw is not None and not isinstance(mw, list):
mw = [mw]

return Agent(
chat_client,
chat_client=chat_client,
instructions=instructions,
id=id,
name=name,
description=description,
chat_message_store_factory=chat_message_store_factory,
conversation_id=conversation_id,
context_providers=context_providers,
middleware=middleware,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
metadata=metadata,
model_id=model_id,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
store=store,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
default_options=default_options,
context_providers=ctx_providers,
middleware=mw,
top_p=top_p,
user=user,
additional_chat_options=additional_chat_options,
**kwargs,
)
10 changes: 5 additions & 5 deletions src/processor/src/libs/agent_framework/agent_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

from typing import Any, Callable, MutableMapping, Sequence

from agent_framework import FunctionTool, MCPStdioTool, MCPStreamableHTTPTool
from agent_framework import FunctionTool
from jinja2 import Template
from openai import BaseModel
Comment thread
Roopan-Microsoft marked this conversation as resolved.
from pydantic import Field

from .agent_framework_helper import AgentFrameworkHelper, ClientType

ToolType = FunctionTool | MCPStreamableHTTPTool | MCPStdioTool | Callable[..., Any] | MutableMapping[str, Any]


class AgentInfo(BaseModel):
agent_name: str
Expand All @@ -23,8 +21,10 @@ class AgentInfo(BaseModel):
agent_instruction: str | None = Field(default=None)
agent_framework_helper: AgentFrameworkHelper | None = Field(default=None)
tools: (
ToolType
| Sequence[ToolType]
FunctionTool
| Callable[..., Any]
| MutableMapping[str, Any]
| Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]]
| None
) = Field(default=None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,117 +325,6 @@ def _bool(name: str, default: bool) -> bool:
)


def _get_content_items(message: Any) -> list[Any]:
"""Return the list of content items from a message, or empty list."""
contents = None
if isinstance(message, dict):
contents = message.get("contents") or message.get("content")
else:
contents = getattr(message, "contents", None) or getattr(message, "content", None)
if isinstance(contents, list):
return contents
return []


def _remove_orphan_tool_messages(messages: list[Any]) -> list[Any]:
"""Remove messages with orphaned function_call or function_result items.

The Responses API requires every function_call in the input to have a
corresponding function_call_output (function_result). If context trimming
breaks these pairs, the API rejects the request.
"""
# Collect call_ids for function_calls and function_results
call_ids_with_call: set[str] = set()
call_ids_with_result: set[str] = set()

for m in messages:
for item in _get_content_items(m):
item_type = None
call_id = None
if isinstance(item, dict):
item_type = item.get("type")
call_id = item.get("call_id")
else:
item_type = getattr(item, "type", None)
call_id = getattr(item, "call_id", None)
if not call_id:
continue
if item_type == "function_call":
call_ids_with_call.add(call_id)
elif item_type == "function_result":
call_ids_with_result.add(call_id)

# Identify orphaned call_ids
orphaned_calls = call_ids_with_call - call_ids_with_result
orphaned_results = call_ids_with_result - call_ids_with_call

if not orphaned_calls and not orphaned_results:
return messages

logger.warning(
"[AOAI_CTX_TRIM] removing orphaned tool messages: %d orphaned calls, %d orphaned results",
len(orphaned_calls),
len(orphaned_results),
)

# Remove messages that ONLY contain orphaned tool items
cleaned: list[Any] = []
for m in messages:
items = _get_content_items(m)
if not items:
cleaned.append(m)
continue

has_orphan = False
has_non_orphan = False
for item in items:
item_type = None
call_id = None
if isinstance(item, dict):
item_type = item.get("type")
call_id = item.get("call_id")
else:
item_type = getattr(item, "type", None)
call_id = getattr(item, "call_id", None)
if call_id and item_type == "function_call" and call_id in orphaned_calls:
has_orphan = True
elif call_id and item_type == "function_result" and call_id in orphaned_results:
has_orphan = True
else:
has_non_orphan = True

if has_orphan and not has_non_orphan:
# Message contains ONLY orphaned tool items — drop it entirely
continue
elif has_orphan and has_non_orphan:
# Message has both orphan and non-orphan content.
# Drop orphaned items if possible, keeping the rest.
if isinstance(items, list) and not isinstance(m, dict):
# Filter out orphaned content items from the message
filtered = []
for item in items:
item_type = getattr(item, "type", None)
call_id = getattr(item, "call_id", None)
if call_id and item_type == "function_call" and call_id in orphaned_calls:
continue
if call_id and item_type == "function_result" and call_id in orphaned_results:
continue
filtered.append(item)
if filtered:
try:
m.contents = filtered
except Exception:
pass
cleaned.append(m)
# else: drop message entirely if no content remains
else:
cleaned.append(m)
else:
cleaned.append(m)

return cleaned


def _trim_messages(
messages: MutableSequence[Any], *, cfg: ContextTrimConfig
) -> list[Any]:
Expand Down Expand Up @@ -525,11 +414,6 @@ def _total_chars(msgs: list[Any]) -> int:
break
combined.pop(drop_index)

# Phase final: Remove orphaned tool call / tool result messages.
# The Responses API requires every function_call to have a matching
# function_call_output. Trimming may break these pairs.
combined = _remove_orphan_tool_messages(combined)

return combined


Expand Down Expand Up @@ -655,27 +539,12 @@ def __init__(
# Map legacy params to OpenAIChatClient params
if deployment_name and "model" not in kwargs:
kwargs["model"] = deployment_name
if endpoint and not kwargs.get("azure_endpoint"):
if endpoint and "azure_endpoint" not in kwargs:
kwargs["azure_endpoint"] = endpoint
if ad_token_provider and kwargs.get("credential") is None:
kwargs["credential"] = ad_token_provider

# Remove None-valued keys that would conflict with env-based settings
for k in list(kwargs):
if kwargs[k] is None:
del kwargs[k]

super().__init__(*args, **kwargs)

# OpenAIChatClient appends /v1/ to azure_endpoint but Azure AI Foundry
# endpoints expect /openai/responses (without /v1/). Fix the base URL.
if hasattr(self, "client") and self.client is not None:
base = str(self.client.base_url)
if "/openai/v1/" in base:
import httpx
corrected = base.replace("/openai/v1/", "/openai/")
self.client._base_url = httpx.URL(corrected)

self._retry_config = retry_config or RateLimitRetryConfig.from_env()
self._context_trim_config = ContextTrimConfig.from_env()

Expand Down
Loading
Loading