Skip to content

Commit 34ff848

Browse files
authored
feat: #2669 add opt-in reasoning content replay for chat completion models (#2670)
1 parent f0df572 commit 34ff848

File tree

6 files changed

+581
-19
lines changed

6 files changed

+581
-19
lines changed

src/agents/extensions/models/litellm_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from ...models.fake_id import FAKE_RESPONSES_ID
5050
from ...models.interface import Model, ModelTracing
5151
from ...models.openai_responses import Converter as OpenAIResponsesConverter
52+
from ...models.reasoning_content_replay import ShouldReplayReasoningContent
5253
from ...retry import ModelRetryAdvice, ModelRetryAdviceRequest
5354
from ...tool import Tool
5455
from ...tracing import generation_span
@@ -146,10 +147,12 @@ def __init__(
146147
model: str,
147148
base_url: str | None = None,
148149
api_key: str | None = None,
150+
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
149151
):
150152
self.model = model
151153
self.base_url = base_url
152154
self.api_key = api_key
155+
self.should_replay_reasoning_content = should_replay_reasoning_content
153156

154157
def get_retry_advice(self, request: ModelRetryAdviceRequest) -> ModelRetryAdvice | None:
155158
# LiteLLM exceptions mirror OpenAI-style status/header fields.
@@ -383,9 +386,11 @@ async def _fetch_response(
383386

384387
converted_messages = Converter.items_to_messages(
385388
input,
389+
base_url=self.base_url,
386390
preserve_thinking_blocks=preserve_thinking_blocks,
387391
preserve_tool_output_all_content=True,
388392
model=self.model,
393+
should_replay_reasoning_content=self.should_replay_reasoning_content,
389394
)
390395

391396
# Fix message ordering: reorder to ensure tool_use comes before tool_result.

src/agents/models/chatcmpl_converter.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@
5555
ensure_tool_choice_supports_backend,
5656
)
5757
from .fake_id import FAKE_RESPONSES_ID
58+
from .reasoning_content_replay import (
59+
ReasoningContentReplayContext,
60+
ReasoningContentSource,
61+
ShouldReplayReasoningContent,
62+
default_should_replay_reasoning_content,
63+
)
5864

5965
ResponseInputContentWithAudioParam = Union[
6066
ResponseInputContentParam,
@@ -422,6 +428,8 @@ def items_to_messages(
422428
model: str | None = None,
423429
preserve_thinking_blocks: bool = False,
424430
preserve_tool_output_all_content: bool = False,
431+
base_url: str | None = None,
432+
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
425433
) -> list[ChatCompletionMessageParam]:
426434
"""
427435
Convert a sequence of 'Item' objects into a list of ChatCompletionMessageParam.
@@ -441,6 +449,12 @@ def items_to_messages(
441449
When True, all content types including images are preserved. This is useful
442450
for model providers (e.g. Anthropic via LiteLLM) that support processing
443451
non-text content in tool results.
452+
base_url: The request base URL, if the caller knows the concrete endpoint.
453+
This is used by reasoning-content replay hooks to distinguish direct
454+
provider calls from proxy or gateway requests.
455+
should_replay_reasoning_content: Optional hook that decides whether a
456+
reasoning item should be replayed into the next assistant message as
457+
`reasoning_content`.
444458
445459
Rules:
446460
- EasyInputMessage or InputMessage (role=user) => ChatCompletionUserMessageParam
@@ -464,8 +478,9 @@ def items_to_messages(
464478
current_assistant_msg: ChatCompletionAssistantMessageParam | None = None
465479
pending_thinking_blocks: list[dict[str, str]] | None = None
466480
pending_reasoning_content: str | None = None # For DeepSeek reasoning_content
481+
normalized_base_url = base_url.rstrip("/") if base_url is not None else None
467482

468-
def flush_assistant_message() -> None:
483+
def flush_assistant_message(*, clear_pending_reasoning_content: bool = True) -> None:
469484
nonlocal current_assistant_msg, pending_reasoning_content
470485
if current_assistant_msg is not None:
471486
# The API doesn't support empty arrays for tool_calls
@@ -475,7 +490,15 @@ def flush_assistant_message() -> None:
475490
pending_reasoning_content = None
476491
result.append(current_assistant_msg)
477492
current_assistant_msg = None
478-
else:
493+
elif clear_pending_reasoning_content:
494+
pending_reasoning_content = None
495+
496+
def apply_pending_reasoning_content(
497+
assistant_msg: ChatCompletionAssistantMessageParam,
498+
) -> None:
499+
nonlocal pending_reasoning_content
500+
if pending_reasoning_content:
501+
assistant_msg["reasoning_content"] = pending_reasoning_content # type: ignore[typeddict-unknown-key]
479502
pending_reasoning_content = None
480503

481504
def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
@@ -485,6 +508,8 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
485508
current_assistant_msg["content"] = None
486509
current_assistant_msg["tool_calls"] = []
487510

511+
apply_pending_reasoning_content(current_assistant_msg)
512+
488513
return current_assistant_msg
489514

490515
for item in items:
@@ -553,7 +578,9 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
553578

554579
# 3) response output message => assistant
555580
elif resp_msg := cls.maybe_response_output_message(item):
556-
flush_assistant_message()
581+
# A reasoning item can be followed by an assistant message and then tool calls
582+
# in the same turn, so preserve pending reasoning_content across this flush.
583+
flush_assistant_message(clear_pending_reasoning_content=False)
557584
new_asst = ChatCompletionAssistantMessageParam(role="assistant")
558585
contents = resp_msg["content"]
559586

@@ -594,6 +621,7 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
594621
pending_thinking_blocks = None # Clear after using
595622

596623
new_asst["tool_calls"] = []
624+
apply_pending_reasoning_content(new_asst)
597625
current_assistant_msg = new_asst
598626

599627
# 4) function/file-search calls => attach to assistant
@@ -619,11 +647,6 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
619647
elif func_call := cls.maybe_function_tool_call(item):
620648
asst = ensure_assistant_message()
621649

622-
# If we have pending reasoning content for DeepSeek, add it to the assistant message
623-
if pending_reasoning_content:
624-
asst["reasoning_content"] = pending_reasoning_content # type: ignore[typeddict-unknown-key]
625-
pending_reasoning_content = None # Clear after using
626-
627650
# If we have pending thinking blocks, use them as the content
628651
# This is required for Anthropic API tool calls with interleaved thinking
629652
if pending_thinking_blocks:
@@ -708,6 +731,7 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
708731

709732
item_provider_data: dict[str, Any] = reasoning_item.get("provider_data", {}) # type: ignore[assignment]
710733
item_model = item_provider_data.get("model", "")
734+
should_replay = False
711735

712736
if (
713737
model
@@ -740,17 +764,23 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
740764
# This preserves the original behavior
741765
pending_thinking_blocks = reconstructed_thinking_blocks
742766

743-
# DeepSeek requires reasoning_content field in assistant messages with tool calls
744-
# Items may not all originate from DeepSeek, so need to check for model match.
745-
# For backward compatibility, if provider_data is missing, ignore the check.
746-
elif (
747-
model
748-
and "deepseek" in model.lower()
749-
and (
750-
(item_model and "deepseek" in item_model.lower())
751-
or item_provider_data == {}
767+
if model is not None:
768+
replay_context = ReasoningContentReplayContext(
769+
model=model,
770+
base_url=normalized_base_url,
771+
reasoning=ReasoningContentSource(
772+
item=reasoning_item,
773+
origin_model=item_model or None,
774+
provider_data=item_provider_data,
775+
),
752776
)
753-
):
777+
should_replay = (
778+
should_replay_reasoning_content(replay_context)
779+
if should_replay_reasoning_content is not None
780+
else default_should_replay_reasoning_content(replay_context)
781+
)
782+
783+
if should_replay:
754784
summary_items = reasoning_item.get("summary", [])
755785
if summary_items:
756786
reasoning_texts = []

src/agents/models/openai_chatcompletions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .fake_id import FAKE_RESPONSES_ID
4040
from .interface import Model, ModelTracing
4141
from .openai_responses import Converter as OpenAIResponsesConverter
42+
from .reasoning_content_replay import ShouldReplayReasoningContent
4243

4344
if TYPE_CHECKING:
4445
from ..model_settings import ModelSettings
@@ -53,9 +54,11 @@ def __init__(
5354
self,
5455
model: str | ChatModel,
5556
openai_client: AsyncOpenAI,
57+
should_replay_reasoning_content: ShouldReplayReasoningContent | None = None,
5658
) -> None:
5759
self.model = model
5860
self._client = openai_client
61+
self.should_replay_reasoning_content = should_replay_reasoning_content
5962

6063
def _non_null_or_omit(self, value: Any) -> Any:
6164
return value if value is not None else omit
@@ -314,7 +317,12 @@ async def _fetch_response(
314317
prompt: ResponsePromptParam | None = None,
315318
) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]:
316319
self._validate_official_openai_input_content_types(input)
317-
converted_messages = Converter.items_to_messages(input, model=self.model)
320+
converted_messages = Converter.items_to_messages(
321+
input,
322+
model=self.model,
323+
base_url=str(self._client.base_url),
324+
should_replay_reasoning_content=self.should_replay_reasoning_content,
325+
)
318326

319327
if system_instructions:
320328
converted_messages.insert(
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Any


@dataclass
class ReasoningContentSource:
    """The reasoning item being considered for replay into the next request."""

    item: Any
    """The raw reasoning item."""

    origin_model: str | None
    """The model that originally produced the reasoning item, if known."""

    provider_data: Mapping[str, Any]
    """Provider-specific metadata captured on the reasoning item."""


@dataclass
class ReasoningContentReplayContext:
    """Context passed to reasoning-content replay hooks."""

    model: str
    """The model that will receive the next Chat Completions request."""

    base_url: str | None
    """The request base URL, if the SDK knows the concrete endpoint."""

    reasoning: ReasoningContentSource
    """The reasoning item candidate being evaluated for replay."""


# Hook signature: given the replay context, decide whether the reasoning item
# should be replayed into the next assistant message as `reasoning_content`.
ShouldReplayReasoningContent = Callable[[ReasoningContentReplayContext], bool]


def default_should_replay_reasoning_content(context: ReasoningContentReplayContext) -> bool:
    """Return whether the SDK should replay reasoning content by default.

    Args:
        context: The replay context describing the target model, the endpoint
            (if known), and the reasoning item candidate.

    Returns:
        True when the next request targets a DeepSeek model and the reasoning
        item either originated from a DeepSeek model or carries no provider
        data (i.e. it predates provider tracking); False otherwise.
    """
    if "deepseek" not in context.model.lower():
        return False

    origin_model = context.reasoning.origin_model
    # Replay only when the current request targets DeepSeek and the reasoning item either
    # came from a DeepSeek model or predates provider tracking. This avoids mixing reasoning
    # content from a different model family into the DeepSeek assistant message.
    return (
        origin_model is not None and "deepseek" in origin_model.lower()
    ) or context.reasoning.provider_data == {}


__all__ = [
    "ReasoningContentReplayContext",
    "ReasoningContentSource",
    "ShouldReplayReasoningContent",
    "default_should_replay_reasoning_content",
]

0 commit comments

Comments
 (0)