Skip to content

Commit 601ecf5

Browse files
authored
fix: #3123 avoid replaying assistant conversation item IDs for OpenAIConversationsSession (#3127)
1 parent 574a598 commit 601ecf5

2 files changed

Lines changed: 292 additions & 16 deletions

File tree

src/agents/run_internal/session_persistence.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ async def prepare_input_with_session(
8686
history = await session.get_items(limit=resolved_settings.limit)
8787
else:
8888
history = await session.get_items()
89+
is_openai_conversation_session = isinstance(session, OpenAIConversationsSession)
8990
converted_history = [
9091
strip_internal_input_item_metadata(ensure_input_item_format(item)) for item in history
9192
]
@@ -122,28 +123,38 @@ async def prepare_input_with_session(
122123
# The callback may reorder, drop, or duplicate items. Keep separate reference maps for
123124
# the copied history and copied new-input lists so we can reconstruct which output items
124125
# belong to the new turn and therefore still need to be persisted.
125-
history_refs = _build_reference_map(history_for_callback)
126+
history_refs = _build_reference_map(
127+
history_for_callback,
128+
ignore_openai_conversation_item_ids=is_openai_conversation_session,
129+
)
126130
new_refs = _build_reference_map(new_items_for_callback)
127-
history_counts = _build_frequency_map(history_for_callback)
131+
history_counts = _build_frequency_map(
132+
history_for_callback,
133+
ignore_openai_conversation_item_ids=is_openai_conversation_session,
134+
)
128135
new_counts = _build_frequency_map(new_items_for_callback)
129136

130137
appended: list[Any] = []
131138
for combined_index, item in enumerate(combined):
132-
key = _session_item_key(item)
133-
if _consume_reference(new_refs, key, item):
134-
new_counts[key] = max(new_counts.get(key, 0) - 1, 0)
139+
history_key = _session_item_key(
140+
item,
141+
ignore_openai_conversation_item_ids=is_openai_conversation_session,
142+
)
143+
new_key = _session_item_key(item)
144+
if _consume_reference(new_refs, new_key, item):
145+
new_counts[new_key] = max(new_counts.get(new_key, 0) - 1, 0)
135146
appended.append(item)
136147
continue
137-
if _consume_reference(history_refs, key, item):
138-
history_counts[key] = max(history_counts.get(key, 0) - 1, 0)
148+
if _consume_reference(history_refs, history_key, item):
149+
history_counts[history_key] = max(history_counts.get(history_key, 0) - 1, 0)
139150
prune_history_indexes.add(combined_index)
140151
continue
141-
if history_counts.get(key, 0) > 0:
142-
history_counts[key] = history_counts.get(key, 0) - 1
152+
if history_counts.get(history_key, 0) > 0:
153+
history_counts[history_key] = history_counts.get(history_key, 0) - 1
143154
prune_history_indexes.add(combined_index)
144155
continue
145-
if new_counts.get(key, 0) > 0:
146-
new_counts[key] = max(new_counts.get(key, 0) - 1, 0)
156+
if new_counts.get(new_key, 0) > 0:
157+
new_counts[new_key] = max(new_counts.get(new_key, 0) - 1, 0)
147158
appended.append(item)
148159
continue
149160
appended.append(item)
@@ -159,6 +170,11 @@ async def prepare_input_with_session(
159170

160171
# Normalize exactly as the runtime does elsewhere so the prepared model input and the
161172
# persisted session items are derived from the same item shape and dedupe rules.
173+
if is_openai_conversation_session and prune_history_indexes:
174+
prepared_items_raw = _sanitize_openai_conversation_history_items_for_model_input(
175+
prepared_items_raw,
176+
prune_history_indexes,
177+
)
162178
prepared_as_inputs = [ensure_input_item_format(item) for item in prepared_items_raw]
163179
filtered = drop_orphan_function_calls(
164180
prepared_as_inputs,
@@ -555,6 +571,32 @@ def _sanitize_openai_conversation_item(item: TResponseInputItem) -> TResponseInp
555571
return item
556572

557573

574+
def _sanitize_openai_conversation_history_items_for_model_input(
575+
items: Sequence[TResponseInputItem],
576+
history_indexes: set[int],
577+
) -> list[TResponseInputItem]:
578+
"""Remove Conversation item metadata only from session-history items sent to the model."""
579+
sanitized_items: list[TResponseInputItem] = []
580+
for index, item in enumerate(items):
581+
if index in history_indexes:
582+
sanitized_items.append(_sanitize_openai_conversation_history_item_for_model_input(item))
583+
else:
584+
sanitized_items.append(item)
585+
return sanitized_items
586+
587+
588+
def _sanitize_openai_conversation_history_item_for_model_input(
589+
item: TResponseInputItem,
590+
) -> TResponseInputItem:
591+
"""Remove Conversation replay metadata from assistant messages only."""
592+
if isinstance(item, dict) and item.get("type") == "message" and item.get("role") == "assistant":
593+
clean_item = cast(dict[str, Any], strip_internal_input_item_metadata(item))
594+
clean_item.pop("id", None)
595+
clean_item.pop("provider_data", None)
596+
return cast(TResponseInputItem, clean_item)
597+
return item
598+
599+
558600
def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: bool) -> str:
559601
"""Fingerprint an item or fall back to repr when unavailable."""
560602
return fingerprint_input_item(item, ignore_ids_for_matching=ignore_ids_for_matching) or repr(
@@ -677,7 +719,7 @@ def _collect_retry_owned_tail_serializations(
677719
return []
678720

679721

680-
def _session_item_key(item: Any) -> str:
722+
def _session_item_key(item: Any, *, ignore_openai_conversation_item_ids: bool = False) -> str:
681723
"""Return a stable representation of a session item for comparison."""
682724
try:
683725
if hasattr(item, "model_dump"):
@@ -691,16 +733,30 @@ def _session_item_key(item: Any) -> str:
691733
dict[str, Any],
692734
strip_internal_input_item_metadata(cast(TResponseInputItem, payload)),
693735
)
736+
if ignore_openai_conversation_item_ids:
737+
payload = cast(
738+
dict[str, Any],
739+
_sanitize_openai_conversation_history_item_for_model_input(
740+
cast(TResponseInputItem, payload)
741+
),
742+
)
694743
return json.dumps(payload, sort_keys=True, default=str)
695744
except Exception:
696745
return repr(item)
697746

698747

699-
def _build_reference_map(items: Sequence[Any]) -> dict[str, list[Any]]:
748+
def _build_reference_map(
749+
items: Sequence[Any],
750+
*,
751+
ignore_openai_conversation_item_ids: bool = False,
752+
) -> dict[str, list[Any]]:
700753
"""Map serialized keys to the concrete session items used to build them."""
701754
refs: dict[str, list[Any]] = {}
702755
for item in items:
703-
key = _session_item_key(item)
756+
key = _session_item_key(
757+
item,
758+
ignore_openai_conversation_item_ids=ignore_openai_conversation_item_ids,
759+
)
704760
refs.setdefault(key, []).append(item)
705761
return refs
706762

@@ -719,10 +775,17 @@ def _consume_reference(ref_map: dict[str, list[Any]], key: str, candidate: Any)
719775
return False
720776

721777

722-
def _build_frequency_map(items: Sequence[Any]) -> dict[str, int]:
778+
def _build_frequency_map(
779+
items: Sequence[Any],
780+
*,
781+
ignore_openai_conversation_item_ids: bool = False,
782+
) -> dict[str, int]:
723783
"""Count how many times each serialized key appears in a collection."""
724784
freq: dict[str, int] = {}
725785
for item in items:
726-
key = _session_item_key(item)
786+
key = _session_item_key(
787+
item,
788+
ignore_openai_conversation_item_ids=ignore_openai_conversation_item_ids,
789+
)
727790
freq[key] = freq.get(key, 0) + 1
728791
return freq

tests/test_agent_runner.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2117,6 +2117,219 @@ def callback(
21172117
assert [cast(dict[str, Any], item).get("content") for item in session_items] == ["new"]
21182118

21192119

2120+
@pytest.mark.asyncio
2121+
async def test_prepare_input_with_openai_conversation_strips_assistant_history_ids() -> None:
2122+
class DummyOpenAIConversationsSession(OpenAIConversationsSession):
2123+
def __init__(self, history: list[TResponseInputItem]) -> None:
2124+
self.history = history
2125+
2126+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
2127+
if limit is None:
2128+
return list(self.history)
2129+
return self.history[-limit:]
2130+
2131+
async def add_items(self, items: list[TResponseInputItem]) -> None:
2132+
self.history.extend(items)
2133+
2134+
async def pop_item(self) -> TResponseInputItem | None:
2135+
return self.history.pop() if self.history else None
2136+
2137+
async def clear_session(self) -> None:
2138+
self.history.clear()
2139+
2140+
history_item = cast(
2141+
TResponseInputItem,
2142+
{
2143+
"id": "conv_item_assistant",
2144+
"type": "message",
2145+
"role": "assistant",
2146+
"content": "history",
2147+
"provider_data": {"server": "metadata"},
2148+
},
2149+
)
2150+
user_history_item = cast(
2151+
TResponseInputItem,
2152+
{
2153+
"id": "conv_item_user",
2154+
"type": "message",
2155+
"role": "user",
2156+
"content": "user history",
2157+
"provider_data": {"server": "metadata"},
2158+
},
2159+
)
2160+
function_call_item = cast(
2161+
TResponseInputItem,
2162+
{
2163+
"id": "conv_item_call",
2164+
"type": "function_call",
2165+
"call_id": "call_history",
2166+
"name": "lookup",
2167+
"arguments": "{}",
2168+
},
2169+
)
2170+
function_call_output_item = cast(
2171+
TResponseInputItem,
2172+
{
2173+
"id": "conv_item_output",
2174+
"type": "function_call_output",
2175+
"call_id": "call_history",
2176+
"output": "ok",
2177+
},
2178+
)
2179+
session = DummyOpenAIConversationsSession(
2180+
history=[user_history_item, history_item, function_call_item, function_call_output_item]
2181+
)
2182+
2183+
prepared, session_items = await prepare_input_with_session("new", session, None)
2184+
2185+
assert isinstance(prepared, list)
2186+
user_payload = cast(dict[str, Any], prepared[0])
2187+
history_payload = cast(dict[str, Any], prepared[1])
2188+
call_payload = cast(dict[str, Any], prepared[2])
2189+
output_payload = cast(dict[str, Any], prepared[3])
2190+
new_payload = cast(dict[str, Any], prepared[4])
2191+
assert user_payload["role"] == "user"
2192+
assert user_payload["id"] == "conv_item_user"
2193+
assert "provider_data" in user_payload
2194+
assert history_payload["role"] == "assistant"
2195+
assert "id" not in history_payload
2196+
assert "provider_data" not in history_payload
2197+
assert call_payload["id"] == "conv_item_call"
2198+
assert output_payload["id"] == "conv_item_output"
2199+
assert new_payload["role"] == "user"
2200+
assert new_payload["content"] == "new"
2201+
assert [cast(dict[str, Any], item).get("content") for item in session_items] == ["new"]
2202+
2203+
2204+
@pytest.mark.asyncio
2205+
async def test_prepare_input_with_regular_session_preserves_history_ids() -> None:
2206+
history_item = cast(
2207+
TResponseInputItem,
2208+
{
2209+
"id": "message_id",
2210+
"type": "message",
2211+
"role": "assistant",
2212+
"content": "history",
2213+
},
2214+
)
2215+
session = SimpleListSession(history=[history_item])
2216+
2217+
prepared, _ = await prepare_input_with_session("new", session, None)
2218+
2219+
assert isinstance(prepared, list)
2220+
history_payload = cast(dict[str, Any], prepared[0])
2221+
assert history_payload["id"] == "message_id"
2222+
2223+
2224+
@pytest.mark.asyncio
2225+
async def test_prepare_input_with_openai_conversation_callback_matches_assistant_no_ids() -> None:
2226+
class DummyOpenAIConversationsSession(OpenAIConversationsSession):
2227+
def __init__(self, history: list[TResponseInputItem]) -> None:
2228+
self.history = history
2229+
2230+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
2231+
if limit is None:
2232+
return list(self.history)
2233+
return self.history[-limit:]
2234+
2235+
async def add_items(self, items: list[TResponseInputItem]) -> None:
2236+
self.history.extend(items)
2237+
2238+
async def pop_item(self) -> TResponseInputItem | None:
2239+
return self.history.pop() if self.history else None
2240+
2241+
async def clear_session(self) -> None:
2242+
self.history.clear()
2243+
2244+
history_item = cast(
2245+
TResponseInputItem,
2246+
{
2247+
"id": "conv_item_assistant",
2248+
"type": "message",
2249+
"role": "assistant",
2250+
"content": "history",
2251+
"provider_data": {"server": "metadata"},
2252+
},
2253+
)
2254+
session = DummyOpenAIConversationsSession(history=[history_item])
2255+
2256+
def callback(
2257+
history: list[TResponseInputItem], new_input: list[TResponseInputItem]
2258+
) -> list[TResponseInputItem]:
2259+
history_copy = dict(cast(dict[str, Any], history[0]))
2260+
history_copy.pop("id", None)
2261+
history_copy.pop("provider_data", None)
2262+
return [
2263+
cast(TResponseInputItem, history_copy),
2264+
cast(TResponseInputItem, dict(cast(dict[str, Any], new_input[0]))),
2265+
]
2266+
2267+
prepared, session_items = await prepare_input_with_session("new", session, callback)
2268+
2269+
assert isinstance(prepared, list)
2270+
assert [cast(dict[str, Any], item).get("content") for item in prepared] == [
2271+
"history",
2272+
"new",
2273+
]
2274+
assert [cast(dict[str, Any], item).get("content") for item in session_items] == ["new"]
2275+
2276+
2277+
@pytest.mark.asyncio
2278+
async def test_prepare_input_with_openai_conversation_callback_keeps_user_ids_distinct() -> None:
2279+
class DummyOpenAIConversationsSession(OpenAIConversationsSession):
2280+
def __init__(self, history: list[TResponseInputItem]) -> None:
2281+
self.history = history
2282+
2283+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
2284+
if limit is None:
2285+
return list(self.history)
2286+
return self.history[-limit:]
2287+
2288+
async def add_items(self, items: list[TResponseInputItem]) -> None:
2289+
self.history.extend(items)
2290+
2291+
async def pop_item(self) -> TResponseInputItem | None:
2292+
return self.history.pop() if self.history else None
2293+
2294+
async def clear_session(self) -> None:
2295+
self.history.clear()
2296+
2297+
history_item = cast(
2298+
TResponseInputItem,
2299+
{
2300+
"id": "conv_item_user",
2301+
"type": "message",
2302+
"role": "user",
2303+
"content": "history",
2304+
"provider_data": {"server": "metadata"},
2305+
},
2306+
)
2307+
session = DummyOpenAIConversationsSession(history=[history_item])
2308+
2309+
def callback(
2310+
history: list[TResponseInputItem], new_input: list[TResponseInputItem]
2311+
) -> list[TResponseInputItem]:
2312+
history_copy = dict(cast(dict[str, Any], history[0]))
2313+
history_copy.pop("id", None)
2314+
history_copy.pop("provider_data", None)
2315+
return [
2316+
cast(TResponseInputItem, history_copy),
2317+
cast(TResponseInputItem, dict(cast(dict[str, Any], new_input[0]))),
2318+
]
2319+
2320+
prepared, session_items = await prepare_input_with_session("new", session, callback)
2321+
2322+
assert isinstance(prepared, list)
2323+
assert [cast(dict[str, Any], item).get("content") for item in prepared] == [
2324+
"history",
2325+
"new",
2326+
]
2327+
assert [cast(dict[str, Any], item).get("content") for item in session_items] == [
2328+
"history",
2329+
"new",
2330+
]
2331+
2332+
21202333
@pytest.mark.asyncio
21212334
async def test_persist_session_items_for_guardrail_trip_uses_original_input_when_missing() -> None:
21222335
session = SimpleListSession()

0 commit comments

Comments (0)