Skip to content

Commit 16e0409

Browse files
authored
fix: #3004 serve HITL resume tool outputs (#3006)
1 parent 5be06a1 commit 16e0409

5 files changed

Lines changed: 203 additions & 34 deletions

File tree

src/agents/run.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
build_resumed_stream_debug_extra,
5353
ensure_context_wrapper,
5454
finalize_conversation_tracking,
55+
get_unsent_tool_call_ids_for_interrupted_state,
5556
input_guardrails_triggered,
5657
resolve_processed_response,
5758
resolve_resumed_context,
@@ -570,6 +571,7 @@ async def run(
570571
generated_items=run_state._generated_items,
571572
model_responses=run_state._model_responses,
572573
session_items=session_input_items,
574+
unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state),
573575
)
574576

575577
tool_use_tracker = AgentToolUseTracker()

src/agents/run_internal/agent_runner_helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from collections.abc import Mapping
56
from typing import Any, cast
67

78
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
@@ -49,6 +50,7 @@
4950
"describe_run_state_step",
5051
"ensure_context_wrapper",
5152
"finalize_conversation_tracking",
53+
"get_unsent_tool_call_ids_for_interrupted_state",
5254
"input_guardrails_triggered",
5355
"validate_session_conversation_settings",
5456
"resolve_trace_settings",
@@ -184,6 +186,41 @@ def apply_resumed_conversation_settings(
184186
return conversation_id, previous_response_id, auto_previous_response_id
185187

186188

189+
def _extract_tool_call_id(raw: Any) -> str | None:
190+
if isinstance(raw, Mapping):
191+
candidate = raw.get("call_id") or raw.get("id")
192+
else:
193+
candidate = getattr(raw, "call_id", None) or getattr(raw, "id", None)
194+
return candidate if isinstance(candidate, str) else None
195+
196+
197+
def get_unsent_tool_call_ids_for_interrupted_state(run_state: RunState[Any] | None) -> set[str]:
    """Return tool call IDs whose local outputs belong to the current interruption.

    Yields an empty set unless the run is actually paused on a
    ``NextStepInterruption`` and has a processed response recorded for the
    interrupted turn.
    """
    if run_state is None:
        return set()
    if not isinstance(run_state._current_step, NextStepInterruption):
        return set()

    processed = run_state._last_processed_response
    if processed is None:
        return set()

    # Every category of tool run captured for the interrupted turn.
    run_groups = (
        processed.handoffs,
        processed.functions,
        processed.computer_actions,
        processed.custom_tool_calls,
        processed.local_shell_calls,
        processed.shell_calls,
        processed.apply_patch_calls,
    )
    return {
        call_id
        for group in run_groups
        for tool_run in group
        if (call_id := _extract_tool_call_id(getattr(tool_run, "tool_call", None))) is not None
    }
222+
223+
187224
def validate_session_conversation_settings(
188225
session: Session | None,
189226
*,

src/agents/run_internal/oai_conversation.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,17 @@ def _is_tool_search_item(item: Any) -> bool:
8484
return item_type in {"tool_search_call", "tool_search_output"}
8585

8686

87+
def _extract_call_id(item: Any) -> str | None:
88+
"""Return a tool call id from mapping or object payloads."""
89+
call_id = item.get("call_id") if isinstance(item, dict) else getattr(item, "call_id", None)
90+
return call_id if isinstance(call_id, str) else None
91+
92+
93+
def _has_output_payload(item: Any) -> bool:
94+
"""Return True when an item carries a local tool output payload."""
95+
return (isinstance(item, dict) and "output" in item) or hasattr(item, "output")
96+
97+
8798
@dataclass
8899
class OpenAIServerConversationTracker:
89100
"""Track server-side conversation state for conversation-aware runs.
@@ -141,6 +152,7 @@ def hydrate_from_state(
141152
generated_items: list[RunItem],
142153
model_responses: list[ModelResponse],
143154
session_items: list[TResponseInputItem] | None = None,
155+
unsent_tool_call_ids: set[str] | None = None,
144156
) -> None:
145157
"""Seed tracking from prior state so resumed runs do not replay already-sent content.
146158
@@ -151,6 +163,7 @@ def hydrate_from_state(
151163
"""
152164
if self.sent_initial_input:
153165
return
166+
unsent_tool_call_ids = unsent_tool_call_ids or set()
154167

155168
normalized_input = original_input
156169
if isinstance(original_input, list):
@@ -189,13 +202,8 @@ def hydrate_from_state(
189202
)
190203
if item_id is not None:
191204
self.server_item_ids.add(item_id)
192-
call_id = (
193-
output_item.get("call_id")
194-
if isinstance(output_item, dict)
195-
else getattr(output_item, "call_id", None)
196-
)
197-
has_output_payload = isinstance(output_item, dict) and "output" in output_item
198-
has_output_payload = has_output_payload or hasattr(output_item, "output")
205+
call_id = _extract_call_id(output_item)
206+
has_output_payload = _has_output_payload(output_item)
199207
if isinstance(call_id, str) and has_output_payload:
200208
self.server_tool_call_ids.add(call_id)
201209

@@ -209,13 +217,8 @@ def hydrate_from_state(
209217
)
210218
if item_id is not None:
211219
self.server_item_ids.add(item_id)
212-
call_id = (
213-
item.get("call_id")
214-
if isinstance(item, dict)
215-
else getattr(item, "call_id", None)
216-
)
217-
has_output = isinstance(item, dict) and "output" in item
218-
has_output = has_output or hasattr(item, "output")
220+
call_id = _extract_call_id(item)
221+
has_output = _has_output_payload(item)
219222
if isinstance(call_id, str) and has_output:
220223
self.server_tool_call_ids.add(call_id)
221224
fp = _fingerprint_for_tracker(item)
@@ -237,10 +240,15 @@ def hydrate_from_state(
237240

238241
if isinstance(raw_item, dict):
239242
item_id = _normalize_server_item_id(raw_item.get("id"))
240-
call_id = raw_item.get("call_id")
241-
has_output_payload = "output" in raw_item
242-
has_output_payload = has_output_payload or hasattr(raw_item, "output")
243+
call_id = _extract_call_id(raw_item)
244+
has_output_payload = _has_output_payload(raw_item)
243245
has_call_id = isinstance(call_id, str)
246+
if (
247+
isinstance(call_id, str)
248+
and has_output_payload
249+
and call_id in unsent_tool_call_ids
250+
):
251+
continue
244252
should_mark = (
245253
item_id is not None
246254
or (has_call_id and (has_output_payload or is_tool_call_item))
@@ -266,9 +274,15 @@ def hydrate_from_state(
266274
self.server_tool_call_ids.add(call_id)
267275
else:
268276
item_id = _normalize_server_item_id(getattr(raw_item, "id", None))
269-
call_id = getattr(raw_item, "call_id", None)
270-
has_output_payload = hasattr(raw_item, "output")
277+
call_id = _extract_call_id(raw_item)
278+
has_output_payload = _has_output_payload(raw_item)
271279
has_call_id = isinstance(call_id, str)
280+
if (
281+
isinstance(call_id, str)
282+
and has_output_payload
283+
and call_id in unsent_tool_call_ids
284+
):
285+
continue
272286
should_mark = (
273287
item_id is not None
274288
or (has_call_id and (has_output_payload or is_tool_call_item))
@@ -309,13 +323,8 @@ def track_server_items(self, model_response: ModelResponse | None) -> None:
309323
)
310324
if item_id is not None:
311325
self.server_item_ids.add(item_id)
312-
call_id = (
313-
output_item.get("call_id")
314-
if isinstance(output_item, dict)
315-
else getattr(output_item, "call_id", None)
316-
)
317-
has_output_payload = isinstance(output_item, dict) and "output" in output_item
318-
has_output_payload = has_output_payload or hasattr(output_item, "output")
326+
call_id = _extract_call_id(output_item)
327+
has_output_payload = _has_output_payload(output_item)
319328
if isinstance(call_id, str) and has_output_payload:
320329
self.server_tool_call_ids.add(call_id)
321330
fp = _fingerprint_for_tracker(output_item)
@@ -445,13 +454,8 @@ def prepare_input(
445454
if item_id is not None and item_id in self.server_item_ids:
446455
continue
447456

448-
call_id = (
449-
raw_item.get("call_id")
450-
if isinstance(raw_item, dict)
451-
else getattr(raw_item, "call_id", None)
452-
)
453-
has_output_payload = isinstance(raw_item, dict) and "output" in raw_item
454-
has_output_payload = has_output_payload or hasattr(raw_item, "output")
457+
call_id = _extract_call_id(raw_item)
458+
has_output_payload = _has_output_payload(raw_item)
455459
if (
456460
isinstance(call_id, str)
457461
and has_output_payload

src/agents/run_internal/run_loop.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from .agent_runner_helpers import (
8787
apply_resumed_conversation_settings,
8888
attach_usage_to_span,
89+
get_unsent_tool_call_ids_for_interrupted_state,
8990
snapshot_usage,
9091
usage_delta,
9192
)
@@ -577,6 +578,7 @@ def _sync_conversation_tracking_from_tracker() -> None:
577578
generated_items=run_state._generated_items,
578579
model_responses=run_state._model_responses,
579580
session_items=session_items,
581+
unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(run_state),
580582
)
581583

582584
streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))

tests/test_server_conversation_tracker.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
1+
from types import SimpleNamespace
12
from typing import Any, cast
23

34
import pytest
5+
from openai.types.responses import ResponseFunctionToolCall
46
from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool
57

68
from agents import Agent, HostedMCPTool
7-
from agents.items import MCPListToolsItem, ModelResponse, RunItem, ToolCallItem, TResponseInputItem
9+
from agents.items import (
10+
MCPListToolsItem,
11+
ModelResponse,
12+
RunItem,
13+
ToolApprovalItem,
14+
ToolCallItem,
15+
ToolCallOutputItem,
16+
TResponseInputItem,
17+
)
818
from agents.lifecycle import RunHooks
919
from agents.models.fake_id import FAKE_RESPONSES_ID
1020
from agents.result import RunResultStreaming
1121
from agents.run_config import ModelInputData, RunConfig
1222
from agents.run_context import RunContextWrapper
1323
from agents.run_internal.agent_bindings import bind_public_agent
24+
from agents.run_internal.agent_runner_helpers import get_unsent_tool_call_ids_for_interrupted_state
1425
from agents.run_internal.oai_conversation import OpenAIServerConversationTracker
1526
from agents.run_internal.run_loop import get_new_response, run_single_turn_streamed
27+
from agents.run_internal.run_steps import NextStepInterruption
1628
from agents.run_internal.tool_use_tracker import AgentToolUseTracker
1729
from agents.stream_events import RunItemStreamEvent
1830
from agents.usage import Usage
@@ -85,6 +97,118 @@ def test_prepare_input_filters_items_seen_by_server_and_tool_calls() -> None:
8597
assert tracker.remaining_initial_input is None
8698

8799

100+
def test_hydrate_from_state_preserves_unsent_outputs_from_interrupted_turn() -> None:
    """Hydration on resume must not treat the interrupted turn's local tool
    outputs as already sent to the server, so they are replayed on resume."""
    test_agent = Agent(name="test")

    def build_call(item_id: str, call_id: str, name: str, arguments: str) -> ResponseFunctionToolCall:
        # Keep the three fixture tool calls terse and uniform.
        return ResponseFunctionToolCall(
            id=item_id,
            type="function_call",
            call_id=call_id,
            name=name,
            arguments=arguments,
            status="completed",
        )

    first_cleanup = build_call("fc_001", "call_CLEANUP1", "run_cleanup", '{"target": "temp_files"}')
    diagnostic = build_call("fc_002", "call_DIAG", "run_diagnostic", '{"check_name": "thermal"}')
    second_cleanup = build_call(
        "fc_003", "call_CLEANUP2", "run_cleanup", '{"target": "winsxs_cache"}'
    )

    response = ModelResponse(
        output=[first_cleanup, diagnostic, second_cleanup],
        usage=Usage(),
        response_id="resp_002",
    )
    diag_output_item = ToolCallOutputItem(
        agent=test_agent,
        raw_item={
            "type": "function_call_output",
            "call_id": "call_DIAG",
            "output": "Diagnostic completed.",
        },
        output="Diagnostic completed.",
    )
    prior_items: list[RunItem] = [
        ToolCallItem(agent=test_agent, raw_item=first_cleanup),
        ToolCallItem(agent=test_agent, raw_item=diagnostic),
        ToolCallItem(agent=test_agent, raw_item=second_cleanup),
        diag_output_item,
        ToolApprovalItem(agent=test_agent, raw_item=first_cleanup, tool_name="run_cleanup"),
        ToolApprovalItem(agent=test_agent, raw_item=second_cleanup, tool_name="run_cleanup"),
    ]
    # Minimal stand-in for a RunState paused on a HITL interruption.
    paused_state = SimpleNamespace(
        _current_step=NextStepInterruption(interruptions=[]),
        _last_processed_response=SimpleNamespace(
            handoffs=[],
            functions=[
                SimpleNamespace(tool_call=first_cleanup),
                SimpleNamespace(tool_call=diagnostic),
                SimpleNamespace(tool_call=second_cleanup),
            ],
            computer_actions=[],
            custom_tool_calls=[],
            local_shell_calls=[],
            shell_calls=[],
            apply_patch_calls=[],
        ),
    )

    tracker = OpenAIServerConversationTracker(previous_response_id="resp_002")
    tracker.hydrate_from_state(
        original_input="Run cleanup, diagnostics, and cleanup.",
        generated_items=prior_items,
        model_responses=[response],
        unsent_tool_call_ids=get_unsent_tool_call_ids_for_interrupted_state(
            cast(Any, paused_state)
        ),
    )

    # The diagnostic output was produced locally during the interrupted turn,
    # so it must remain eligible for sending.
    assert "call_DIAG" not in tracker.server_tool_call_ids

    prepared = tracker.prepare_input(
        "Run cleanup, diagnostics, and cleanup.",
        [
            ToolCallItem(agent=test_agent, raw_item=first_cleanup),
            ToolCallItem(agent=test_agent, raw_item=diagnostic),
            ToolCallItem(agent=test_agent, raw_item=second_cleanup),
            diag_output_item,
            ToolCallOutputItem(
                agent=test_agent,
                raw_item={
                    "type": "function_call_output",
                    "call_id": "call_CLEANUP1",
                    "output": "Tool call not approved.",
                },
                output="Tool call not approved.",
            ),
            ToolCallOutputItem(
                agent=test_agent,
                raw_item={
                    "type": "function_call_output",
                    "call_id": "call_CLEANUP2",
                    "output": "Tool call not approved.",
                },
                output="Tool call not approved.",
            ),
        ],
    )

    # All three local outputs — the diagnostic result and both rejection
    # messages — must be served back, in their original order.
    prepared_output_ids = [
        entry.get("call_id")
        for entry in prepared
        if isinstance(entry, dict) and entry.get("type") == "function_call_output"
    ]
    assert prepared_output_ids == ["call_DIAG", "call_CLEANUP1", "call_CLEANUP2"]
210+
211+
88212
def test_hydrate_from_state_does_not_track_string_initial_input_by_object_identity() -> None:
89213
tracker = OpenAIServerConversationTracker(
90214
conversation_id="conv-init-string", previous_response_id=None

0 commit comments

Comments
 (0)