Skip to content

Commit 99e941b

Browse files
Merge pull request #277 from microsoft/psl-agentframework
fix: enhance message handling and context management in orchestrators
2 parents 32cac9f + 6f36041 commit 99e941b

11 files changed

Lines changed: 385 additions & 184 deletions

src/processor/src/libs/agent_framework/agent_builder.py

Lines changed: 99 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AgentMiddleware,
1212
BaseChatClient,
1313
ChatMiddleware,
14+
ChatOptions,
1415
ContextProvider,
1516
FunctionTool,
1617
ToolMode,
@@ -441,32 +442,61 @@ def build(self) -> Agent:
441442
async with agent:
442443
response = await agent.run("Hello!")
443444
"""
445+
# Build default_options from model parameters
446+
options_dict: dict[str, Any] = {}
447+
if self._frequency_penalty is not None:
448+
options_dict["frequency_penalty"] = self._frequency_penalty
449+
if self._logit_bias is not None:
450+
options_dict["logit_bias"] = self._logit_bias
451+
if self._max_tokens is not None:
452+
options_dict["max_tokens"] = self._max_tokens
453+
if self._metadata is not None:
454+
options_dict["metadata"] = self._metadata
455+
if self._model_id is not None:
456+
options_dict["model"] = self._model_id
457+
if self._presence_penalty is not None:
458+
options_dict["presence_penalty"] = self._presence_penalty
459+
if self._response_format is not None:
460+
options_dict["response_format"] = self._response_format
461+
if self._seed is not None:
462+
options_dict["seed"] = self._seed
463+
if self._stop is not None:
464+
options_dict["stop"] = self._stop
465+
if self._store is not None:
466+
options_dict["store"] = self._store
467+
if self._temperature is not None:
468+
options_dict["temperature"] = self._temperature
469+
if self._tool_choice is not None:
470+
options_dict["tool_choice"] = self._tool_choice
471+
if self._top_p is not None:
472+
options_dict["top_p"] = self._top_p
473+
if self._user is not None:
474+
options_dict["user"] = self._user
475+
if self._additional_chat_options:
476+
options_dict.update(self._additional_chat_options)
477+
478+
default_options = ChatOptions(**options_dict) if options_dict else None
479+
480+
# Agent expects context_providers as a Sequence; wrap single instance in a list
481+
ctx_providers = self._context_providers
482+
if ctx_providers is not None and not isinstance(ctx_providers, list):
483+
ctx_providers = [ctx_providers]
484+
485+
# Agent expects middleware as a Sequence; wrap single instance in a list
486+
mw = self._middleware
487+
if mw is not None and not isinstance(mw, list):
488+
mw = [mw]
489+
444490
return Agent(
445-
chat_client=self._chat_client,
491+
self._chat_client,
446492
instructions=self._instructions,
447493
id=self._id,
448494
name=self._name,
449495
description=self._description,
450-
chat_message_store_factory=self._chat_message_store_factory,
451-
conversation_id=self._conversation_id,
452-
context_providers=self._context_providers,
453-
middleware=self._middleware,
454-
frequency_penalty=self._frequency_penalty,
455-
logit_bias=self._logit_bias,
456-
max_tokens=self._max_tokens,
457-
metadata=self._metadata,
458-
model_id=self._model_id,
459-
presence_penalty=self._presence_penalty,
460-
response_format=self._response_format,
461-
seed=self._seed,
462-
stop=self._stop,
463-
store=self._store,
464-
temperature=self._temperature,
465-
tool_choice=self._tool_choice,
466496
tools=self._tools,
467-
top_p=self._top_p,
468-
user=self._user,
469-
additional_chat_options=self._additional_chat_options,
497+
default_options=default_options,
498+
context_providers=ctx_providers,
499+
middleware=mw,
470500
**self._kwargs,
471501
)
472502

@@ -755,31 +785,60 @@ def create_agent(
755785
``async with`` to ensure proper initialization and cleanup via the Agent's
756786
async context manager protocol.
757787
"""
788+
# Build default_options from model parameters
789+
opts: dict[str, Any] = {}
790+
if frequency_penalty is not None:
791+
opts["frequency_penalty"] = frequency_penalty
792+
if logit_bias is not None:
793+
opts["logit_bias"] = logit_bias
794+
if max_tokens is not None:
795+
opts["max_tokens"] = max_tokens
796+
if metadata is not None:
797+
opts["metadata"] = metadata
798+
if model_id is not None:
799+
opts["model"] = model_id
800+
if presence_penalty is not None:
801+
opts["presence_penalty"] = presence_penalty
802+
if response_format is not None:
803+
opts["response_format"] = response_format
804+
if seed is not None:
805+
opts["seed"] = seed
806+
if stop is not None:
807+
opts["stop"] = stop
808+
if store is not None:
809+
opts["store"] = store
810+
if temperature is not None:
811+
opts["temperature"] = temperature
812+
if tool_choice is not None:
813+
opts["tool_choice"] = tool_choice
814+
if top_p is not None:
815+
opts["top_p"] = top_p
816+
if user is not None:
817+
opts["user"] = user
818+
if additional_chat_options:
819+
opts.update(additional_chat_options)
820+
821+
default_options = ChatOptions(**opts) if opts else None
822+
823+
# Agent expects context_providers as a Sequence; wrap single instance in a list
824+
ctx_providers = context_providers
825+
if ctx_providers is not None and not isinstance(ctx_providers, list):
826+
ctx_providers = [ctx_providers]
827+
828+
# Agent expects middleware as a Sequence; wrap single instance in a list
829+
mw = middleware
830+
if mw is not None and not isinstance(mw, list):
831+
mw = [mw]
832+
758833
return Agent(
759-
chat_client=chat_client,
834+
chat_client,
760835
instructions=instructions,
761836
id=id,
762837
name=name,
763838
description=description,
764-
chat_message_store_factory=chat_message_store_factory,
765-
conversation_id=conversation_id,
766-
context_providers=context_providers,
767-
middleware=middleware,
768-
frequency_penalty=frequency_penalty,
769-
logit_bias=logit_bias,
770-
max_tokens=max_tokens,
771-
metadata=metadata,
772-
model_id=model_id,
773-
presence_penalty=presence_penalty,
774-
response_format=response_format,
775-
seed=seed,
776-
stop=stop,
777-
store=store,
778-
temperature=temperature,
779-
tool_choice=tool_choice,
780839
tools=tools,
781-
top_p=top_p,
782-
user=user,
783-
additional_chat_options=additional_chat_options,
840+
default_options=default_options,
841+
context_providers=ctx_providers,
842+
middleware=mw,
784843
**kwargs,
785844
)

src/processor/src/libs/agent_framework/agent_info.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
from typing import Any, Callable, MutableMapping, Sequence
77

8-
from agent_framework import FunctionTool
8+
from agent_framework import FunctionTool, MCPStdioTool, MCPStreamableHTTPTool
99
from jinja2 import Template
1010
from openai import BaseModel
1111
from pydantic import Field
1212

1313
from .agent_framework_helper import AgentFrameworkHelper, ClientType
1414

15+
ToolType = FunctionTool | MCPStreamableHTTPTool | MCPStdioTool | Callable[..., Any] | MutableMapping[str, Any]
16+
1517

1618
class AgentInfo(BaseModel):
1719
agent_name: str
@@ -21,10 +23,8 @@ class AgentInfo(BaseModel):
2123
agent_instruction: str | None = Field(default=None)
2224
agent_framework_helper: AgentFrameworkHelper | None = Field(default=None)
2325
tools: (
24-
FunctionTool
25-
| Callable[..., Any]
26-
| MutableMapping[str, Any]
27-
| Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]]
26+
ToolType
27+
| Sequence[ToolType]
2828
| None
2929
) = Field(default=None)
3030

src/processor/src/libs/agent_framework/azure_openai_response_retry.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,117 @@ def _bool(name: str, default: bool) -> bool:
325325
)
326326

327327

328+
def _get_content_items(message: Any) -> list[Any]:
329+
"""Return the list of content items from a message, or empty list."""
330+
contents = None
331+
if isinstance(message, dict):
332+
contents = message.get("contents") or message.get("content")
333+
else:
334+
contents = getattr(message, "contents", None) or getattr(message, "content", None)
335+
if isinstance(contents, list):
336+
return contents
337+
return []
338+
339+
340+
def _remove_orphan_tool_messages(messages: list[Any]) -> list[Any]:
341+
"""Remove messages with orphaned function_call or function_result items.
342+
343+
The Responses API requires every function_call in the input to have a
344+
corresponding function_call_output (function_result). If context trimming
345+
breaks these pairs, the API rejects the request.
346+
"""
347+
# Collect call_ids for function_calls and function_results
348+
call_ids_with_call: set[str] = set()
349+
call_ids_with_result: set[str] = set()
350+
351+
for m in messages:
352+
for item in _get_content_items(m):
353+
item_type = None
354+
call_id = None
355+
if isinstance(item, dict):
356+
item_type = item.get("type")
357+
call_id = item.get("call_id")
358+
else:
359+
item_type = getattr(item, "type", None)
360+
call_id = getattr(item, "call_id", None)
361+
if not call_id:
362+
continue
363+
if item_type == "function_call":
364+
call_ids_with_call.add(call_id)
365+
elif item_type == "function_result":
366+
call_ids_with_result.add(call_id)
367+
368+
# Identify orphaned call_ids
369+
orphaned_calls = call_ids_with_call - call_ids_with_result
370+
orphaned_results = call_ids_with_result - call_ids_with_call
371+
372+
if not orphaned_calls and not orphaned_results:
373+
return messages
374+
375+
logger.warning(
376+
"[AOAI_CTX_TRIM] removing orphaned tool messages: %d orphaned calls, %d orphaned results",
377+
len(orphaned_calls),
378+
len(orphaned_results),
379+
)
380+
381+
# Remove messages that ONLY contain orphaned tool items
382+
cleaned: list[Any] = []
383+
for m in messages:
384+
items = _get_content_items(m)
385+
if not items:
386+
cleaned.append(m)
387+
continue
388+
389+
has_orphan = False
390+
has_non_orphan = False
391+
for item in items:
392+
item_type = None
393+
call_id = None
394+
if isinstance(item, dict):
395+
item_type = item.get("type")
396+
call_id = item.get("call_id")
397+
else:
398+
item_type = getattr(item, "type", None)
399+
call_id = getattr(item, "call_id", None)
400+
if call_id and item_type == "function_call" and call_id in orphaned_calls:
401+
has_orphan = True
402+
elif call_id and item_type == "function_result" and call_id in orphaned_results:
403+
has_orphan = True
404+
else:
405+
has_non_orphan = True
406+
407+
if has_orphan and not has_non_orphan:
408+
# Message contains ONLY orphaned tool items — drop it entirely
409+
continue
410+
elif has_orphan and has_non_orphan:
411+
# Message has both orphan and non-orphan content.
412+
# Drop orphaned items if possible, keeping the rest.
413+
if isinstance(items, list) and not isinstance(m, dict):
414+
# Filter out orphaned content items from the message
415+
filtered = []
416+
for item in items:
417+
item_type = getattr(item, "type", None)
418+
call_id = getattr(item, "call_id", None)
419+
if call_id and item_type == "function_call" and call_id in orphaned_calls:
420+
continue
421+
if call_id and item_type == "function_result" and call_id in orphaned_results:
422+
continue
423+
filtered.append(item)
424+
if filtered:
425+
try:
426+
m.contents = filtered
427+
except Exception:
428+
pass
429+
cleaned.append(m)
430+
# else: drop message entirely if no content remains
431+
else:
432+
cleaned.append(m)
433+
else:
434+
cleaned.append(m)
435+
436+
return cleaned
437+
438+
328439
def _trim_messages(
329440
messages: MutableSequence[Any], *, cfg: ContextTrimConfig
330441
) -> list[Any]:
@@ -414,6 +525,11 @@ def _total_chars(msgs: list[Any]) -> int:
414525
break
415526
combined.pop(drop_index)
416527

528+
# Phase final: Remove orphaned tool call / tool result messages.
529+
# The Responses API requires every function_call to have a matching
530+
# function_call_output. Trimming may break these pairs.
531+
combined = _remove_orphan_tool_messages(combined)
532+
417533
return combined
418534

419535

@@ -539,12 +655,27 @@ def __init__(
539655
# Map legacy params to OpenAIChatClient params
540656
if deployment_name and "model" not in kwargs:
541657
kwargs["model"] = deployment_name
542-
if endpoint and "azure_endpoint" not in kwargs:
658+
if endpoint and not kwargs.get("azure_endpoint"):
543659
kwargs["azure_endpoint"] = endpoint
544660
if ad_token_provider and kwargs.get("credential") is None:
545661
kwargs["credential"] = ad_token_provider
546662

663+
# Remove None-valued keys that would conflict with env-based settings
664+
for k in list(kwargs):
665+
if kwargs[k] is None:
666+
del kwargs[k]
667+
547668
super().__init__(*args, **kwargs)
669+
670+
# OpenAIChatClient appends /v1/ to azure_endpoint but Azure AI Foundry
671+
# endpoints expect /openai/responses (without /v1/). Fix the base URL.
672+
if hasattr(self, "client") and self.client is not None:
673+
base = str(self.client.base_url)
674+
if "/openai/v1/" in base:
675+
import httpx
676+
corrected = base.replace("/openai/v1/", "/openai/")
677+
self.client._base_url = httpx.URL(corrected)
678+
548679
self._retry_config = retry_config or RateLimitRetryConfig.from_env()
549680
self._context_trim_config = ContextTrimConfig.from_env()
550681

0 commit comments

Comments
 (0)