Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
build_resumed_stream_debug_extra,
ensure_context_wrapper,
finalize_conversation_tracking,
get_unsent_tool_call_ids_for_interrupted_state,
input_guardrails_triggered,
resolve_processed_response,
resolve_resumed_context,
Expand Down Expand Up @@ -570,6 +571,7 @@ async def run(
generated_items=run_state._generated_items,
model_responses=run_state._model_responses,
session_items=session_input_items,
unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state),
)

tool_use_tracker = AgentToolUseTracker()
Expand Down
37 changes: 37 additions & 0 deletions src/agents/run_internal/agent_runner_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Mapping
from typing import Any, cast

from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
Expand Down Expand Up @@ -49,6 +50,7 @@
"describe_run_state_step",
"ensure_context_wrapper",
"finalize_conversation_tracking",
"get_unsent_tool_call_ids_for_interrupted_state",
"input_guardrails_triggered",
"validate_session_conversation_settings",
"resolve_trace_settings",
Expand Down Expand Up @@ -184,6 +186,41 @@ def apply_resumed_conversation_settings(
return conversation_id, previous_response_id, auto_previous_response_id


def _extract_tool_call_id(raw: Any) -> str | None:
if isinstance(raw, Mapping):
candidate = raw.get("call_id") or raw.get("id")
else:
candidate = getattr(raw, "call_id", None) or getattr(raw, "id", None)
return candidate if isinstance(candidate, str) else None


def get_unsent_tool_call_ids_for_interrupted_state(run_state: RunState[Any] | None) -> set[str]:
"""Return tool call IDs whose local outputs belong to the current interruption."""
if run_state is None or not isinstance(run_state._current_step, NextStepInterruption):
return set()

processed_response = run_state._last_processed_response
if processed_response is None:
return set()

tool_call_ids: set[str] = set()
tool_run_groups = (
processed_response.handoffs,
processed_response.functions,
processed_response.computer_actions,
processed_response.custom_tool_calls,
processed_response.local_shell_calls,
processed_response.shell_calls,
processed_response.apply_patch_calls,
)
for tool_runs in tool_run_groups:
for tool_run in tool_runs:
call_id = _extract_tool_call_id(getattr(tool_run, "tool_call", None))
if call_id is not None:
tool_call_ids.add(call_id)
return tool_call_ids


def validate_session_conversation_settings(
session: Session | None,
*,
Expand Down
70 changes: 37 additions & 33 deletions src/agents/run_internal/oai_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ def _is_tool_search_item(item: Any) -> bool:
return item_type in {"tool_search_call", "tool_search_output"}


def _extract_call_id(item: Any) -> str | None:
"""Return a tool call id from mapping or object payloads."""
call_id = item.get("call_id") if isinstance(item, dict) else getattr(item, "call_id", None)
return call_id if isinstance(call_id, str) else None


def _has_output_payload(item: Any) -> bool:
"""Return True when an item carries a local tool output payload."""
return (isinstance(item, dict) and "output" in item) or hasattr(item, "output")


@dataclass
class OpenAIServerConversationTracker:
"""Track server-side conversation state for conversation-aware runs.
Expand Down Expand Up @@ -141,6 +152,7 @@ def hydrate_from_state(
generated_items: list[RunItem],
model_responses: list[ModelResponse],
session_items: list[TResponseInputItem] | None = None,
unsent_tool_call_ids: set[str] | None = None,
) -> None:
"""Seed tracking from prior state so resumed runs do not replay already-sent content.

Expand All @@ -151,6 +163,7 @@ def hydrate_from_state(
"""
if self.sent_initial_input:
return
unsent_tool_call_ids = unsent_tool_call_ids or set()

normalized_input = original_input
if isinstance(original_input, list):
Expand Down Expand Up @@ -189,13 +202,8 @@ def hydrate_from_state(
)
if item_id is not None:
self.server_item_ids.add(item_id)
call_id = (
output_item.get("call_id")
if isinstance(output_item, dict)
else getattr(output_item, "call_id", None)
)
has_output_payload = isinstance(output_item, dict) and "output" in output_item
has_output_payload = has_output_payload or hasattr(output_item, "output")
call_id = _extract_call_id(output_item)
has_output_payload = _has_output_payload(output_item)
if isinstance(call_id, str) and has_output_payload:
self.server_tool_call_ids.add(call_id)

Expand All @@ -209,13 +217,8 @@ def hydrate_from_state(
)
if item_id is not None:
self.server_item_ids.add(item_id)
call_id = (
item.get("call_id")
if isinstance(item, dict)
else getattr(item, "call_id", None)
)
has_output = isinstance(item, dict) and "output" in item
has_output = has_output or hasattr(item, "output")
call_id = _extract_call_id(item)
has_output = _has_output_payload(item)
if isinstance(call_id, str) and has_output:
self.server_tool_call_ids.add(call_id)
fp = _fingerprint_for_tracker(item)
Expand All @@ -237,10 +240,15 @@ def hydrate_from_state(

if isinstance(raw_item, dict):
item_id = _normalize_server_item_id(raw_item.get("id"))
call_id = raw_item.get("call_id")
has_output_payload = "output" in raw_item
has_output_payload = has_output_payload or hasattr(raw_item, "output")
call_id = _extract_call_id(raw_item)
has_output_payload = _has_output_payload(raw_item)
has_call_id = isinstance(call_id, str)
if (
isinstance(call_id, str)
and has_output_payload
and call_id in unsent_tool_call_ids
):
continue
should_mark = (
item_id is not None
or (has_call_id and (has_output_payload or is_tool_call_item))
Expand All @@ -266,9 +274,15 @@ def hydrate_from_state(
self.server_tool_call_ids.add(call_id)
else:
item_id = _normalize_server_item_id(getattr(raw_item, "id", None))
call_id = getattr(raw_item, "call_id", None)
has_output_payload = hasattr(raw_item, "output")
call_id = _extract_call_id(raw_item)
has_output_payload = _has_output_payload(raw_item)
has_call_id = isinstance(call_id, str)
if (
isinstance(call_id, str)
and has_output_payload
and call_id in unsent_tool_call_ids
):
continue
should_mark = (
item_id is not None
or (has_call_id and (has_output_payload or is_tool_call_item))
Expand Down Expand Up @@ -309,13 +323,8 @@ def track_server_items(self, model_response: ModelResponse | None) -> None:
)
if item_id is not None:
self.server_item_ids.add(item_id)
call_id = (
output_item.get("call_id")
if isinstance(output_item, dict)
else getattr(output_item, "call_id", None)
)
has_output_payload = isinstance(output_item, dict) and "output" in output_item
has_output_payload = has_output_payload or hasattr(output_item, "output")
call_id = _extract_call_id(output_item)
has_output_payload = _has_output_payload(output_item)
if isinstance(call_id, str) and has_output_payload:
self.server_tool_call_ids.add(call_id)
fp = _fingerprint_for_tracker(output_item)
Expand Down Expand Up @@ -445,13 +454,8 @@ def prepare_input(
if item_id is not None and item_id in self.server_item_ids:
continue

call_id = (
raw_item.get("call_id")
if isinstance(raw_item, dict)
else getattr(raw_item, "call_id", None)
)
has_output_payload = isinstance(raw_item, dict) and "output" in raw_item
has_output_payload = has_output_payload or hasattr(raw_item, "output")
call_id = _extract_call_id(raw_item)
has_output_payload = _has_output_payload(raw_item)
if (
isinstance(call_id, str)
and has_output_payload
Expand Down
2 changes: 2 additions & 0 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from .agent_runner_helpers import (
apply_resumed_conversation_settings,
attach_usage_to_span,
get_unsent_tool_call_ids_for_interrupted_state,
snapshot_usage,
usage_delta,
)
Expand Down Expand Up @@ -577,6 +578,7 @@ def _sync_conversation_tracking_from_tracker() -> None:
generated_items=run_state._generated_items,
model_responses=run_state._model_responses,
session_items=session_items,
unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state),
)

streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))
Expand Down
126 changes: 125 additions & 1 deletion tests/test_server_conversation_tracker.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
from types import SimpleNamespace
from typing import Any, cast

import pytest
from openai.types.responses import ResponseFunctionToolCall
from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool

from agents import Agent, HostedMCPTool
from agents.items import MCPListToolsItem, ModelResponse, RunItem, ToolCallItem, TResponseInputItem
from agents.items import (
MCPListToolsItem,
ModelResponse,
RunItem,
ToolApprovalItem,
ToolCallItem,
ToolCallOutputItem,
TResponseInputItem,
)
from agents.lifecycle import RunHooks
from agents.models.fake_id import FAKE_RESPONSES_ID
from agents.result import RunResultStreaming
from agents.run_config import ModelInputData, RunConfig
from agents.run_context import RunContextWrapper
from agents.run_internal.agent_bindings import bind_public_agent
from agents.run_internal.agent_runner_helpers import get_unsent_tool_call_ids_for_interrupted_state
from agents.run_internal.oai_conversation import OpenAIServerConversationTracker
from agents.run_internal.run_loop import get_new_response, run_single_turn_streamed
from agents.run_internal.run_steps import NextStepInterruption
from agents.run_internal.tool_use_tracker import AgentToolUseTracker
from agents.stream_events import RunItemStreamEvent
from agents.usage import Usage
Expand Down Expand Up @@ -85,6 +97,118 @@ def test_prepare_input_filters_items_seen_by_server_and_tool_calls() -> None:
assert tracker.remaining_initial_input is None


def test_hydrate_from_state_preserves_unsent_outputs_from_interrupted_turn() -> None:
agent = Agent(name="test")
cleanup1_call = ResponseFunctionToolCall(
id="fc_001",
type="function_call",
call_id="call_CLEANUP1",
name="run_cleanup",
arguments='{"target": "temp_files"}',
status="completed",
)
diagnostic_call = ResponseFunctionToolCall(
id="fc_002",
type="function_call",
call_id="call_DIAG",
name="run_diagnostic",
arguments='{"check_name": "thermal"}',
status="completed",
)
cleanup2_call = ResponseFunctionToolCall(
id="fc_003",
type="function_call",
call_id="call_CLEANUP2",
name="run_cleanup",
arguments='{"target": "winsxs_cache"}',
status="completed",
)
model_response = ModelResponse(
output=[cleanup1_call, diagnostic_call, cleanup2_call],
usage=Usage(),
response_id="resp_002",
)
diagnostic_output = ToolCallOutputItem(
agent=agent,
raw_item={
"type": "function_call_output",
"call_id": "call_DIAG",
"output": "Diagnostic completed.",
},
output="Diagnostic completed.",
)
generated_items: list[RunItem] = [
ToolCallItem(agent=agent, raw_item=cleanup1_call),
ToolCallItem(agent=agent, raw_item=diagnostic_call),
ToolCallItem(agent=agent, raw_item=cleanup2_call),
diagnostic_output,
ToolApprovalItem(agent=agent, raw_item=cleanup1_call, tool_name="run_cleanup"),
ToolApprovalItem(agent=agent, raw_item=cleanup2_call, tool_name="run_cleanup"),
]
interrupted_state = SimpleNamespace(
_current_step=NextStepInterruption(interruptions=[]),
_last_processed_response=SimpleNamespace(
handoffs=[],
functions=[
SimpleNamespace(tool_call=cleanup1_call),
SimpleNamespace(tool_call=diagnostic_call),
SimpleNamespace(tool_call=cleanup2_call),
],
computer_actions=[],
custom_tool_calls=[],
local_shell_calls=[],
shell_calls=[],
apply_patch_calls=[],
),
)

tracker = OpenAIServerConversationTracker(previous_response_id="resp_002")
tracker.hydrate_from_state(
original_input="Run cleanup, diagnostics, and cleanup.",
generated_items=generated_items,
model_responses=[model_response],
unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(
cast(Any, interrupted_state)
),
)

assert "call_DIAG" not in tracker.server_tool_call_ids

prepared = tracker.prepare_input(
"Run cleanup, diagnostics, and cleanup.",
[
ToolCallItem(agent=agent, raw_item=cleanup1_call),
ToolCallItem(agent=agent, raw_item=diagnostic_call),
ToolCallItem(agent=agent, raw_item=cleanup2_call),
diagnostic_output,
ToolCallOutputItem(
agent=agent,
raw_item={
"type": "function_call_output",
"call_id": "call_CLEANUP1",
"output": "Tool call not approved.",
},
output="Tool call not approved.",
),
ToolCallOutputItem(
agent=agent,
raw_item={
"type": "function_call_output",
"call_id": "call_CLEANUP2",
"output": "Tool call not approved.",
},
output="Tool call not approved.",
),
],
)

assert [
item.get("call_id")
for item in prepared
if isinstance(item, dict) and item.get("type") == "function_call_output"
] == ["call_DIAG", "call_CLEANUP1", "call_CLEANUP2"]


def test_hydrate_from_state_does_not_track_string_initial_input_by_object_identity() -> None:
tracker = OpenAIServerConversationTracker(
conversation_id="conv-init-string", previous_response_id=None
Expand Down
Loading