From 60bafb536c2f9b073f6ef65ae891032fa95025a3 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Thu, 23 Apr 2026 09:44:46 +0900 Subject: [PATCH] fix: #3004 serve HITL resume tool outputs --- src/agents/run.py | 2 + .../run_internal/agent_runner_helpers.py | 37 +++++ src/agents/run_internal/oai_conversation.py | 70 +++++----- src/agents/run_internal/run_loop.py | 2 + tests/test_server_conversation_tracker.py | 126 +++++++++++++++++- 5 files changed, 203 insertions(+), 34 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index f116cc1fdd..68fa27b3bb 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -52,6 +52,7 @@ build_resumed_stream_debug_extra, ensure_context_wrapper, finalize_conversation_tracking, + get_unsent_tool_call_ids_for_interrupted_state, input_guardrails_triggered, resolve_processed_response, resolve_resumed_context, @@ -570,6 +571,7 @@ async def run( generated_items=run_state._generated_items, model_responses=run_state._model_responses, session_items=session_input_items, + unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state), ) tool_use_tracker = AgentToolUseTracker() diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index e79f7ba656..a1115b5a1e 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Mapping from typing import Any, cast from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails @@ -49,6 +50,7 @@ "describe_run_state_step", "ensure_context_wrapper", "finalize_conversation_tracking", + "get_unsent_tool_call_ids_for_interrupted_state", "input_guardrails_triggered", "validate_session_conversation_settings", "resolve_trace_settings", @@ -184,6 +186,41 @@ def apply_resumed_conversation_settings( return conversation_id, previous_response_id, auto_previous_response_id +def _extract_tool_call_id(raw: Any) -> str | None: + if isinstance(raw, Mapping): + candidate = raw.get("call_id") or raw.get("id") + else: + candidate = getattr(raw, "call_id", None) or getattr(raw, "id", None) + return candidate if isinstance(candidate, str) else None + + +def get_unsent_tool_call_ids_for_interrupted_state(run_state: RunState[Any] | None) -> set[str]: + """Return tool call IDs whose local outputs belong to the current interruption.""" + if run_state is None or not isinstance(run_state._current_step, NextStepInterruption): + return set() + + processed_response = run_state._last_processed_response + if processed_response is None: + return set() + + tool_call_ids: set[str] = set() + tool_run_groups = ( + processed_response.handoffs, + processed_response.functions, + processed_response.computer_actions, + processed_response.custom_tool_calls, + processed_response.local_shell_calls, + processed_response.shell_calls, + processed_response.apply_patch_calls, + ) + for tool_runs in tool_run_groups: + for tool_run in tool_runs: + call_id = _extract_tool_call_id(getattr(tool_run, "tool_call", None)) + if call_id is not None: + tool_call_ids.add(call_id) + return tool_call_ids + + def validate_session_conversation_settings( session: Session | None, *, diff --git a/src/agents/run_internal/oai_conversation.py b/src/agents/run_internal/oai_conversation.py index 233898d5cf..84d638f74e 100644 --- a/src/agents/run_internal/oai_conversation.py +++ b/src/agents/run_internal/oai_conversation.py @@ -84,6 +84,17 @@ def _is_tool_search_item(item: Any) -> bool: return item_type in {"tool_search_call", "tool_search_output"} +def _extract_call_id(item: Any) -> str | None: + """Return a tool call id from mapping or object payloads.""" + call_id = item.get("call_id") if isinstance(item, dict) else getattr(item, "call_id", None) + return call_id if isinstance(call_id, str) else None + + +def _has_output_payload(item: Any) -> bool: + """Return True when an item carries a local tool output payload.""" + return (isinstance(item, dict) and "output" in item) or hasattr(item, "output") + + @dataclass class OpenAIServerConversationTracker: """Track server-side conversation state for conversation-aware runs. @@ -141,6 +152,7 @@ def hydrate_from_state( generated_items: list[RunItem], model_responses: list[ModelResponse], session_items: list[TResponseInputItem] | None = None, + unsent_tool_call_ids: set[str] | None = None, ) -> None: """Seed tracking from prior state so resumed runs do not replay already-sent content. @@ -151,6 +163,7 @@ def hydrate_from_state( """ if self.sent_initial_input: return + unsent_tool_call_ids = unsent_tool_call_ids or set() normalized_input = original_input if isinstance(original_input, list): @@ -189,13 +202,8 @@ def hydrate_from_state( ) if item_id is not None: self.server_item_ids.add(item_id) - call_id = ( - output_item.get("call_id") - if isinstance(output_item, dict) - else getattr(output_item, "call_id", None) - ) - has_output_payload = isinstance(output_item, dict) and "output" in output_item - has_output_payload = has_output_payload or hasattr(output_item, "output") + call_id = _extract_call_id(output_item) + has_output_payload = _has_output_payload(output_item) if isinstance(call_id, str) and has_output_payload: self.server_tool_call_ids.add(call_id) @@ -209,13 +217,8 @@ def hydrate_from_state( ) if item_id is not None: self.server_item_ids.add(item_id) - call_id = ( - item.get("call_id") - if isinstance(item, dict) - else getattr(item, "call_id", None) - ) - has_output = isinstance(item, dict) and "output" in item - has_output = has_output or hasattr(item, "output") + call_id = _extract_call_id(item) + has_output = _has_output_payload(item) if isinstance(call_id, str) and has_output: self.server_tool_call_ids.add(call_id) fp = _fingerprint_for_tracker(item) @@ -237,10 +240,15 @@ def hydrate_from_state( if isinstance(raw_item, dict): item_id = _normalize_server_item_id(raw_item.get("id")) - call_id = raw_item.get("call_id") - has_output_payload = "output" in raw_item - has_output_payload = has_output_payload or hasattr(raw_item, "output") + call_id = _extract_call_id(raw_item) + has_output_payload = _has_output_payload(raw_item) has_call_id = isinstance(call_id, str) + if ( + isinstance(call_id, str) + and has_output_payload + and call_id in unsent_tool_call_ids + ): + continue should_mark = ( item_id is not None or (has_call_id and (has_output_payload or is_tool_call_item)) @@ -266,9 +274,15 @@ def hydrate_from_state( self.server_tool_call_ids.add(call_id) else: item_id = _normalize_server_item_id(getattr(raw_item, "id", None)) - call_id = getattr(raw_item, "call_id", None) - has_output_payload = hasattr(raw_item, "output") + call_id = _extract_call_id(raw_item) + has_output_payload = _has_output_payload(raw_item) has_call_id = isinstance(call_id, str) + if ( + isinstance(call_id, str) + and has_output_payload + and call_id in unsent_tool_call_ids + ): + continue should_mark = ( item_id is not None or (has_call_id and (has_output_payload or is_tool_call_item)) @@ -309,13 +323,8 @@ def track_server_items(self, model_response: ModelResponse | None) -> None: ) if item_id is not None: self.server_item_ids.add(item_id) - call_id = ( - output_item.get("call_id") - if isinstance(output_item, dict) - else getattr(output_item, "call_id", None) - ) - has_output_payload = isinstance(output_item, dict) and "output" in output_item - has_output_payload = has_output_payload or hasattr(output_item, "output") + call_id = _extract_call_id(output_item) + has_output_payload = _has_output_payload(output_item) if isinstance(call_id, str) and has_output_payload: self.server_tool_call_ids.add(call_id) fp = _fingerprint_for_tracker(output_item) @@ -445,13 +454,8 @@ def prepare_input( if item_id is not None and item_id in self.server_item_ids: continue - call_id = ( - raw_item.get("call_id") - if isinstance(raw_item, dict) - else getattr(raw_item, "call_id", None) - ) - has_output_payload = isinstance(raw_item, dict) and "output" in raw_item - has_output_payload = has_output_payload or hasattr(raw_item, "output") + call_id = _extract_call_id(raw_item) + has_output_payload = _has_output_payload(raw_item) if ( isinstance(call_id, str) and has_output_payload diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 0de58b95b6..039088ecb6 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -86,6 +86,7 @@ from .agent_runner_helpers import ( apply_resumed_conversation_settings, attach_usage_to_span, + get_unsent_tool_call_ids_for_interrupted_state, snapshot_usage, usage_delta, ) @@ -577,6 +578,7 @@ def _sync_conversation_tracking_from_tracker() -> None: generated_items=run_state._generated_items, model_responses=run_state._model_responses, session_items=session_items, + unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state), ) streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) diff --git a/tests/test_server_conversation_tracker.py b/tests/test_server_conversation_tracker.py index cbe533c69e..703e2c6824 100644 --- a/tests/test_server_conversation_tracker.py +++ b/tests/test_server_conversation_tracker.py @@ -1,18 +1,30 @@ +from types import SimpleNamespace from typing import Any, cast import pytest +from openai.types.responses import ResponseFunctionToolCall from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool from agents import Agent, HostedMCPTool -from agents.items import MCPListToolsItem, ModelResponse, RunItem, ToolCallItem, TResponseInputItem +from agents.items import ( + MCPListToolsItem, + ModelResponse, + RunItem, + ToolApprovalItem, + ToolCallItem, + ToolCallOutputItem, + TResponseInputItem, +) from agents.lifecycle import RunHooks from agents.models.fake_id import FAKE_RESPONSES_ID from agents.result import RunResultStreaming from agents.run_config import ModelInputData, RunConfig from agents.run_context import RunContextWrapper from agents.run_internal.agent_bindings import bind_public_agent +from agents.run_internal.agent_runner_helpers import get_unsent_tool_call_ids_for_interrupted_state from agents.run_internal.oai_conversation import OpenAIServerConversationTracker from agents.run_internal.run_loop import get_new_response, run_single_turn_streamed +from agents.run_internal.run_steps import NextStepInterruption from agents.run_internal.tool_use_tracker import AgentToolUseTracker from agents.stream_events import RunItemStreamEvent from agents.usage import Usage @@ -85,6 +97,118 @@ def test_prepare_input_filters_items_seen_by_server_and_tool_calls() -> None: assert tracker.remaining_initial_input is None +def test_hydrate_from_state_preserves_unsent_outputs_from_interrupted_turn() -> None: + agent = Agent(name="test") + cleanup1_call = ResponseFunctionToolCall( + id="fc_001", + type="function_call", + call_id="call_CLEANUP1", + name="run_cleanup", + arguments='{"target": "temp_files"}', + status="completed", + ) + diagnostic_call = ResponseFunctionToolCall( + id="fc_002", + type="function_call", + call_id="call_DIAG", + name="run_diagnostic", + arguments='{"check_name": "thermal"}', + status="completed", + ) + cleanup2_call = ResponseFunctionToolCall( + id="fc_003", + type="function_call", + call_id="call_CLEANUP2", + name="run_cleanup", + arguments='{"target": "winsxs_cache"}', + status="completed", + ) + model_response = ModelResponse( + output=[cleanup1_call, diagnostic_call, cleanup2_call], + usage=Usage(), + response_id="resp_002", + ) + diagnostic_output = ToolCallOutputItem( + agent=agent, + raw_item={ + "type": "function_call_output", + "call_id": "call_DIAG", + "output": "Diagnostic completed.", + }, + output="Diagnostic completed.", + ) + generated_items: list[RunItem] = [ + ToolCallItem(agent=agent, raw_item=cleanup1_call), + ToolCallItem(agent=agent, raw_item=diagnostic_call), + ToolCallItem(agent=agent, raw_item=cleanup2_call), + diagnostic_output, + ToolApprovalItem(agent=agent, raw_item=cleanup1_call, tool_name="run_cleanup"), + ToolApprovalItem(agent=agent, raw_item=cleanup2_call, tool_name="run_cleanup"), + ] + interrupted_state = SimpleNamespace( + _current_step=NextStepInterruption(interruptions=[]), + _last_processed_response=SimpleNamespace( + handoffs=[], + functions=[ + SimpleNamespace(tool_call=cleanup1_call), + SimpleNamespace(tool_call=diagnostic_call), + SimpleNamespace(tool_call=cleanup2_call), + ], + computer_actions=[], + custom_tool_calls=[], + local_shell_calls=[], + shell_calls=[], + apply_patch_calls=[], + ), + ) + + tracker = OpenAIServerConversationTracker(previous_response_id="resp_002") + tracker.hydrate_from_state( + original_input="Run cleanup, diagnostics, and cleanup.", + generated_items=generated_items, + model_responses=[model_response], + unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state( + cast(Any, interrupted_state) + ), + ) + + assert "call_DIAG" not in tracker.server_tool_call_ids + + prepared = tracker.prepare_input( + "Run cleanup, diagnostics, and cleanup.", + [ + ToolCallItem(agent=agent, raw_item=cleanup1_call), + ToolCallItem(agent=agent, raw_item=diagnostic_call), + ToolCallItem(agent=agent, raw_item=cleanup2_call), + diagnostic_output, + ToolCallOutputItem( + agent=agent, + raw_item={ + "type": "function_call_output", + "call_id": "call_CLEANUP1", + "output": "Tool call not approved.", + }, + output="Tool call not approved.", + ), + ToolCallOutputItem( + agent=agent, + raw_item={ + "type": "function_call_output", + "call_id": "call_CLEANUP2", + "output": "Tool call not approved.", + }, + output="Tool call not approved.", + ), + ], + ) + + assert [ + item.get("call_id") + for item in prepared + if isinstance(item, dict) and item.get("type") == "function_call_output" + ] == ["call_DIAG", "call_CLEANUP1", "call_CLEANUP2"] + + def test_hydrate_from_state_does_not_track_string_initial_input_by_object_identity() -> None: tracker = OpenAIServerConversationTracker( conversation_id="conv-init-string", previous_response_id=None