Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/agents/extensions/models/litellm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from ...models.fake_id import FAKE_RESPONSES_ID
from ...models.interface import Model, ModelTracing
from ...models.openai_responses import Converter as OpenAIResponsesConverter
from ...models.reasoning_content_replay import ShouldReplayReasoningContent
from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
from ...tool import Tool
from ...tracing import generation_span
Expand Down Expand Up @@ -146,10 +147,12 @@ def __init__(
model: str,
base_url: str | None = None,
api_key: str | None = None,
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
):
self.model = model
self.base_url = base_url
self.api_key = api_key
self.should_replay_reasoning_content = should_replay_reasoning_content

def get_retry_advice(self, request: ModelRetryAdviceRequest) -> ModelRetryAdvice | None:
# LiteLLM exceptions mirror OpenAI-style status/header fields.
Expand Down Expand Up @@ -383,9 +386,11 @@ async def _fetch_response(

converted_messages = Converter.items_to_messages(
input,
base_url=self.base_url,
preserve_thinking_blocks=preserve_thinking_blocks,
preserve_tool_output_all_content=True,
model=self.model,
should_replay_reasoning_content=self.should_replay_reasoning_content,
)

# Fix message ordering: reorder to ensure tool_use comes before tool_result.
Expand Down
43 changes: 30 additions & 13 deletions src/agents/models/chatcmpl_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
ensure_tool_choice_supports_backend,
)
from .fake_id import FAKE_RESPONSES_ID
from .reasoning_content_replay import (
ReasoningContentReplayContext,
ReasoningContentSource,
ShouldReplayReasoningContent,
default_should_replay_reasoning_content,
)

ResponseInputContentWithAudioParam = Union[
ResponseInputContentParam,
Expand Down Expand Up @@ -420,8 +426,10 @@ def items_to_messages(
cls,
items: str | Iterable[TResponseInputItem],
model: str | None = None,
base_url: str | None = None,
preserve_thinking_blocks: bool = False,
preserve_tool_output_all_content: bool = False,
Comment thread
seratch marked this conversation as resolved.
Outdated
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
) -> list[ChatCompletionMessageParam]:
"""
Convert a sequence of 'Item' objects into a list of ChatCompletionMessageParam.
Expand Down Expand Up @@ -465,7 +473,7 @@ def items_to_messages(
pending_thinking_blocks: list[dict[str, str]] | None = None
pending_reasoning_content: str | None = None # For DeepSeek reasoning_content

def flush_assistant_message() -> None:
def flush_assistant_message(*, clear_pending_reasoning_content: bool = True) -> None:
nonlocal current_assistant_msg, pending_reasoning_content
if current_assistant_msg is not None:
# The API doesn't support empty arrays for tool_calls
Expand All @@ -475,7 +483,7 @@ def flush_assistant_message() -> None:
pending_reasoning_content = None
result.append(current_assistant_msg)
current_assistant_msg = None
else:
elif clear_pending_reasoning_content:
pending_reasoning_content = None

def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
Expand Down Expand Up @@ -553,7 +561,9 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:

# 3) response output message => assistant
elif resp_msg := cls.maybe_response_output_message(item):
flush_assistant_message()
# A reasoning item can be followed by an assistant message and then tool calls
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a different bug I found during the feature addition this time.

# in the same turn, so preserve pending reasoning_content across this flush.
flush_assistant_message(clear_pending_reasoning_content=False)
new_asst = ChatCompletionAssistantMessageParam(role="assistant")
contents = resp_msg["content"]

Expand Down Expand Up @@ -708,6 +718,7 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:

item_provider_data: dict[str, Any] = reasoning_item.get("provider_data", {}) # type: ignore[assignment]
item_model = item_provider_data.get("model", "")
should_replay = False

if (
model
Expand Down Expand Up @@ -740,17 +751,23 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
# This preserves the original behavior
pending_thinking_blocks = reconstructed_thinking_blocks

# DeepSeek requires reasoning_content field in assistant messages with tool calls
# Items may not all originate from DeepSeek, so need to check for model match.
# For backward compatibility, if provider_data is missing, ignore the check.
elif (
model
and "deepseek" in model.lower()
and (
(item_model and "deepseek" in item_model.lower())
or item_provider_data == {}
elif model is not None:
replay_context = ReasoningContentReplayContext(
model=model,
base_url=base_url,
reasoning=ReasoningContentSource(
item=reasoning_item,
origin_model=item_model or None,
provider_data=item_provider_data,
),
)
):
should_replay = (
should_replay_reasoning_content(replay_context)
if should_replay_reasoning_content is not None
else default_should_replay_reasoning_content(replay_context)
)

if should_replay:
summary_items = reasoning_item.get("summary", [])
if summary_items:
reasoning_texts = []
Expand Down
10 changes: 9 additions & 1 deletion src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .fake_id import FAKE_RESPONSES_ID
from .interface import Model, ModelTracing
from .openai_responses import Converter as OpenAIResponsesConverter
from .reasoning_content_replay import ShouldReplayReasoningContent

if TYPE_CHECKING:
from ..model_settings import ModelSettings
Expand All @@ -53,9 +54,11 @@ def __init__(
self,
model: str | ChatModel,
openai_client: AsyncOpenAI,
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
) -> None:
self.model = model
self._client = openai_client
self.should_replay_reasoning_content = should_replay_reasoning_content

def _non_null_or_omit(self, value: Any) -> Any:
return value if value is not None else omit
Expand Down Expand Up @@ -314,7 +317,12 @@ async def _fetch_response(
prompt: ResponsePromptParam | None = None,
) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]:
self._validate_official_openai_input_content_types(input)
converted_messages = Converter.items_to_messages(input, model=self.model)
converted_messages = Converter.items_to_messages(
input,
model=self.model,
base_url=str(self._client.base_url),
should_replay_reasoning_content=self.should_replay_reasoning_content,
)

if system_instructions:
converted_messages.insert(
Expand Down
59 changes: 59 additions & 0 deletions src/agents/models/reasoning_content_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ReasoningContentSource:
"""The reasoning item being considered for replay into the next request."""

item: Any
"""The raw reasoning item."""

origin_model: str | None
"""The model that originally produced the reasoning item, if known."""

provider_data: Mapping[str, Any]
"""Provider-specific metadata captured on the reasoning item."""


@dataclass
class ReasoningContentReplayContext:
"""Context passed to reasoning-content replay hooks."""

model: str
"""The model that will receive the next Chat Completions request."""

base_url: str | None
"""The request base URL, if the SDK knows the concrete endpoint."""

reasoning: ReasoningContentSource
"""The reasoning item candidate being evaluated for replay."""


ShouldReplayReasoningContent = Callable[[ReasoningContentReplayContext], bool]


def default_should_replay_reasoning_content(context: ReasoningContentReplayContext) -> bool:
"""Return whether the SDK should replay reasoning content by default."""

if "deepseek" not in context.model.lower():
return False

origin_model = context.reasoning.origin_model
# Replay only when the current request targets DeepSeek and the reasoning item either
# came from a DeepSeek model or predates provider tracking. This avoids mixing reasoning
# content from a different model family into the DeepSeek assistant message.
return (
origin_model is not None and "deepseek" in origin_model.lower()
) or context.reasoning.provider_data == {}


__all__ = [
"ReasoningContentReplayContext",
"ReasoningContentSource",
"ShouldReplayReasoningContent",
"default_should_replay_reasoning_content",
]
Loading
Loading