Skip to content

Commit 1f93475

Browse files
committed
fix: thread run context wrapper through session retry-rewind and compaction paths
Addresses review feedback: a wrapper-aware session that scopes storage by wrapper.context was getting the wrapper on normal reads/writes but not on the conversation-retry rewind path or the compaction decorator's internal history reads/replacements, so those operated on the unscoped/default store. - Thread the wrapper through rewind_session_items, the tail-suffix rewind, cleanup verification, and popped-item restoration. - Thread the wrapper through OpenAIResponsesCompactionSession's run_compaction, candidate loading, and replace/restore helpers, and add it to the run_compaction protocol method. clear_session/pop_item keep their existing signatures, matching the get_items/add_items-only scope of this change.
1 parent cdf4f03 commit 1f93475

5 files changed

Lines changed: 160 additions & 38 deletions

File tree

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def _resolve_compaction_mode_for_response(
159159
return "input"
160160
return _resolve_compaction_mode(mode, response_id=response_id, store=store)
161161

162-
async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None:
162+
async def run_compaction(
163+
self,
164+
args: OpenAIResponsesCompactionArgs | None = None,
165+
*,
166+
wrapper: RunContextWrapper[Any] | None = None,
167+
) -> None:
163168
"""Run compaction using responses.compact API."""
164169
if args and args.get("response_id"):
165170
self._response_id = args["response_id"]
@@ -184,7 +189,9 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
184189
"when using previous_response_id compaction."
185190
)
186191

187-
compaction_candidate_items, session_items = await self._ensure_compaction_candidates()
192+
compaction_candidate_items, session_items = await self._ensure_compaction_candidates(
193+
wrapper=wrapper
194+
)
188195

189196
force = args.get("force", False) if args else False
190197
should_compact = force or self.should_trigger_compaction(
@@ -220,10 +227,11 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
220227
_normalize_compaction_output_items(compacted.output or [])
221228
)
222229

223-
previous_items = await self._get_all_underlying_session_items()
230+
previous_items = await self._get_all_underlying_session_items(wrapper=wrapper)
224231
await self._replace_underlying_session_items(
225232
output_items=output_items,
226233
previous_items=previous_items,
234+
wrapper=wrapper,
227235
)
228236

229237
self._compaction_candidate_items = select_compaction_candidate_items(output_items)
@@ -235,7 +243,7 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None
235243
f"candidates={len(self._compaction_candidate_items)})"
236244
)
237245

238-
async def get_items(
246+
async def _underlying_get_items(
239247
self,
240248
limit: int | None = None,
241249
*,
@@ -249,37 +257,65 @@ async def get_items(
249257
return await self.underlying_session.get_items(limit, wrapper=wrapper)
250258
return await self.underlying_session.get_items(limit)
251259

252-
async def _get_all_underlying_session_items(self) -> list[TResponseInputItem]:
253-
return await self.underlying_session.get_items(limit=_ALL_SESSION_ITEMS_LIMIT)
260+
async def _underlying_add_items(
261+
self,
262+
items: list[TResponseInputItem],
263+
*,
264+
wrapper: RunContextWrapper[Any] | None = None,
265+
) -> None:
266+
if wrapper is not None and session_method_accepts_wrapper(
267+
self.underlying_session.add_items
268+
):
269+
await self.underlying_session.add_items(items, wrapper=wrapper)
270+
return
271+
await self.underlying_session.add_items(items)
272+
273+
async def get_items(
274+
self,
275+
limit: int | None = None,
276+
*,
277+
wrapper: RunContextWrapper[Any] | None = None,
278+
) -> list[TResponseInputItem]:
279+
return await self._underlying_get_items(limit, wrapper=wrapper)
280+
281+
async def _get_all_underlying_session_items(
282+
self, *, wrapper: RunContextWrapper[Any] | None = None
283+
) -> list[TResponseInputItem]:
284+
return await self._underlying_get_items(_ALL_SESSION_ITEMS_LIMIT, wrapper=wrapper)
254285

255286
async def _replace_underlying_session_items(
256287
self,
257288
*,
258289
output_items: list[TResponseInputItem],
259290
previous_items: list[TResponseInputItem],
291+
wrapper: RunContextWrapper[Any] | None = None,
260292
) -> None:
261293
try:
262294
await self.underlying_session.clear_session()
263295
except Exception as clear_error:
264296
await self._restore_underlying_session_items_after_failed_clear(
265-
previous_items, clear_error
297+
previous_items, clear_error, wrapper=wrapper
266298
)
267299
raise
268300

269301
try:
270302
if output_items:
271-
await self.underlying_session.add_items(output_items)
303+
await self._underlying_add_items(output_items, wrapper=wrapper)
272304
except Exception as replacement_error:
273-
await self._restore_underlying_session_items(previous_items, replacement_error)
305+
await self._restore_underlying_session_items(
306+
previous_items, replacement_error, wrapper=wrapper
307+
)
274308
raise
275309

276310
async def _restore_underlying_session_items_after_failed_clear(
277311
self,
278312
previous_items: list[TResponseInputItem],
279313
clear_error: Exception,
314+
*,
315+
wrapper: RunContextWrapper[Any] | None = None,
280316
) -> None:
281317
try:
282-
current_items = await self._get_all_underlying_session_items()
318+
current_items = await self._get_all_underlying_session_items(wrapper=wrapper)
283319
except Exception:
284320
logger.warning(
285321
"Failed to inspect session history after compaction replacement clear failed.",
@@ -291,7 +327,7 @@ async def _restore_underlying_session_items_after_failed_clear(
291327
return
292328

293329
await self._restore_underlying_session_items(
294-
previous_items, clear_error, clear_existing_items=False
330+
previous_items, clear_error, clear_existing_items=False, wrapper=wrapper
295331
)
296332

297333
async def _restore_underlying_session_items(
@@ -300,12 +336,13 @@ async def _restore_underlying_session_items(
300336
replacement_error: Exception,
301337
*,
302338
clear_existing_items: bool = True,
339+
wrapper: RunContextWrapper[Any] | None = None,
303340
) -> None:
304341
try:
305342
if clear_existing_items:
306343
await self.underlying_session.clear_session()
307344
if previous_items:
308-
await self.underlying_session.add_items(list(previous_items))
345+
await self._underlying_add_items(list(previous_items), wrapper=wrapper)
309346
except Exception:
310347
logger.warning(
311348
"Failed to restore session history after compaction replacement failed.",
@@ -318,10 +355,18 @@ async def _restore_underlying_session_items(
318355
replacement_error,
319356
)
320357

321-
async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None:
358+
async def _defer_compaction(
359+
self,
360+
response_id: str,
361+
store: bool | None = None,
362+
*,
363+
wrapper: RunContextWrapper[Any] | None = None,
364+
) -> None:
322365
if self._deferred_response_id is not None:
323366
return
324-
compaction_candidate_items, session_items = await self._ensure_compaction_candidates()
367+
compaction_candidate_items, session_items = await self._ensure_compaction_candidates(
368+
wrapper=wrapper
369+
)
325370
resolved_mode = self._resolve_compaction_mode_for_response(
326371
response_id=response_id,
327372
store=store,
@@ -350,12 +395,7 @@ async def add_items(
350395
*,
351396
wrapper: RunContextWrapper[Any] | None = None,
352397
) -> None:
353-
if wrapper is not None and session_method_accepts_wrapper(
354-
self.underlying_session.add_items
355-
):
356-
await self.underlying_session.add_items(items, wrapper=wrapper)
357-
else:
358-
await self.underlying_session.add_items(items)
398+
await self._underlying_add_items(items, wrapper=wrapper)
359399
if self._compaction_candidate_items is not None:
360400
new_items = _normalize_compaction_session_items(items)
361401
new_candidates = select_compaction_candidate_items(new_items)
@@ -379,12 +419,16 @@ async def clear_session(self) -> None:
379419

380420
async def _ensure_compaction_candidates(
381421
self,
422+
*,
423+
wrapper: RunContextWrapper[Any] | None = None,
382424
) -> tuple[list[TResponseInputItem], list[TResponseInputItem]]:
383425
"""Lazy-load and cache compaction candidates."""
384426
if self._compaction_candidate_items is not None and self._session_items is not None:
385427
return (self._compaction_candidate_items[:], self._session_items[:])
386428

387-
history = _normalize_compaction_session_items(await self.underlying_session.get_items())
429+
history = _normalize_compaction_session_items(
430+
await self._underlying_get_items(wrapper=wrapper)
431+
)
388432
candidates = select_compaction_candidate_items(history)
389433
self._compaction_candidate_items = candidates
390434
self._session_items = history

src/agents/memory/session.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,19 @@ class OpenAIResponsesCompactionArgs(TypedDict, total=False):
162162
class OpenAIResponsesCompactionAwareSession(Session, Protocol):
163163
"""Protocol for session implementations that support responses compaction."""
164164

165-
async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None:
166-
"""Run the compaction process for the session."""
165+
async def run_compaction(
166+
self,
167+
args: OpenAIResponsesCompactionArgs | None = None,
168+
*,
169+
wrapper: RunContextWrapper[Any] | None = None,
170+
) -> None:
171+
"""Run the compaction process for the session.
172+
173+
Args:
174+
args: Optional compaction arguments.
175+
wrapper: Optional run context wrapper for the current run. Implementations
176+
may ignore this parameter.
177+
"""
167178
...
168179

169180

src/agents/run_internal/run_loop.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,7 +1463,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
14631463

14641464
async def rewind_model_request() -> None:
14651465
items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else []
1466-
await rewind_session_items(session, items_to_rewind, server_conversation_tracker)
1466+
await rewind_session_items(
1467+
session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper
1468+
)
14671469
if server_conversation_tracker is not None:
14681470
server_conversation_tracker.rewind_input(filtered.input)
14691471

@@ -1887,7 +1889,9 @@ async def get_new_response(
18871889

18881890
async def rewind_model_request() -> None:
18891891
items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else []
1890-
await rewind_session_items(session, items_to_rewind, server_conversation_tracker)
1892+
await rewind_session_items(
1893+
session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper
1894+
)
18911895
if server_conversation_tracker is not None:
18921896
server_conversation_tracker.rewind_input(filtered.input)
18931897

src/agents/run_internal/session_persistence.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,10 @@ async def save_result_to_session(
398398
if has_local_tool_outputs:
399399
defer_compaction = getattr(session, "_defer_compaction", None)
400400
if callable(defer_compaction):
401-
result = defer_compaction(response_id, store=store)
401+
if session_method_accepts_wrapper(defer_compaction):
402+
result = defer_compaction(response_id, store=store, wrapper=wrapper)
403+
else:
404+
result = defer_compaction(response_id, store=store)
402405
if inspect.isawaitable(result):
403406
await result
404407
logger.debug(
@@ -424,7 +427,10 @@ async def save_result_to_session(
424427
}
425428
if store is not None:
426429
compaction_args["store"] = store
427-
await session.run_compaction(compaction_args)
430+
if session_method_accepts_wrapper(session.run_compaction):
431+
await session.run_compaction(compaction_args, wrapper=wrapper)
432+
else:
433+
await session.run_compaction(compaction_args)
428434

429435
return saved_run_items_count
430436

@@ -459,6 +465,8 @@ async def rewind_session_items(
459465
session: Session | None,
460466
items: Sequence[TResponseInputItem],
461467
server_tracker: OpenAIServerConversationTracker | None = None,
468+
*,
469+
wrapper: RunContextWrapper[Any] | None = None,
462470
) -> None:
463471
"""
464472
Best-effort helper to roll back items recently persisted to a session when a conversation
@@ -499,6 +507,7 @@ async def rewind_session_items(
499507
"Skipping session rewind because the current tail does not match the retry-owned suffix"
500508
),
501509
pop_failure_warning="Failed to rewind session item: %s",
510+
wrapper=wrapper,
502511
)
503512
if not rewound:
504513
return
@@ -507,13 +516,14 @@ async def rewind_session_items(
507516
session,
508517
snapshot_serializations,
509518
ignore_ids_for_matching=ignore_ids_for_matching,
519+
wrapper=wrapper,
510520
)
511521

512522
if session is None or server_tracker is None:
513523
return
514524

515525
try:
516-
latest_items = await session.get_items(limit=1)
526+
latest_items = await _session_get_items(session, limit=1, wrapper=wrapper)
517527
except Exception as exc:
518528
logger.debug("Failed to peek session items while rewinding: %s", exc)
519529
return
@@ -526,7 +536,7 @@ async def rewind_session_items(
526536
return
527537

528538
try:
529-
session_items = await session.get_items()
539+
session_items = await _session_get_items(session, wrapper=wrapper)
530540
except Exception as exc:
531541
logger.debug("Failed to inspect session tail while stripping stray items: %s", exc)
532542
return
@@ -554,6 +564,7 @@ async def rewind_session_items(
554564
"retry-owned conversation items"
555565
),
556566
pop_failure_warning="Failed to strip stray session item: %s",
567+
wrapper=wrapper,
557568
)
558569

559570

@@ -563,6 +574,7 @@ async def wait_for_session_cleanup(
563574
*,
564575
max_attempts: int = 5,
565576
ignore_ids_for_matching: bool = False,
577+
wrapper: RunContextWrapper[Any] | None = None,
566578
) -> None:
567579
"""
568580
Confirm that rewound items are no longer present in the session tail so the store stays
@@ -575,7 +587,7 @@ async def wait_for_session_cleanup(
575587

576588
for attempt in range(max_attempts):
577589
try:
578-
tail_items = await session.get_items(limit=window)
590+
tail_items = await _session_get_items(session, limit=window, wrapper=wrapper)
579591
except Exception as exc:
580592
logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc)
581593
await asyncio.sleep(0.1 * (attempt + 1))
@@ -700,13 +712,16 @@ async def _rewind_session_tail_suffix(
700712
ignore_ids_for_matching: bool,
701713
mismatch_warning: str,
702714
pop_failure_warning: str,
715+
wrapper: RunContextWrapper[Any] | None = None,
703716
) -> bool:
704717
"""Remove an exact serialized suffix from the session tail, aborting when the tail diverges."""
705718
if not expected_serializations:
706719
return True
707720

708721
try:
709-
tail_items = await session.get_items(limit=len(expected_serializations))
722+
tail_items = await _session_get_items(
723+
session, limit=len(expected_serializations), wrapper=wrapper
724+
)
710725
except Exception as exc:
711726
logger.warning(pop_failure_warning, exc)
712727
return False
@@ -734,12 +749,12 @@ async def _rewind_session_tail_suffix(
734749
if inspect.isawaitable(result):
735750
result = await result
736751
except Exception as exc:
737-
await _restore_popped_session_items(session, popped_items)
752+
await _restore_popped_session_items(session, popped_items, wrapper=wrapper)
738753
logger.warning(pop_failure_warning, exc)
739754
return False
740755

741756
if result is None:
742-
await _restore_popped_session_items(session, popped_items)
757+
await _restore_popped_session_items(session, popped_items, wrapper=wrapper)
743758
logger.warning(mismatch_warning)
744759
return False
745760

@@ -748,15 +763,18 @@ async def _rewind_session_tail_suffix(
748763
result, ignore_ids_for_matching=ignore_ids_for_matching
749764
)
750765
if popped_serialized != expected:
751-
await _restore_popped_session_items(session, popped_items)
766+
await _restore_popped_session_items(session, popped_items, wrapper=wrapper)
752767
logger.warning(mismatch_warning)
753768
return False
754769

755770
return True
756771

757772

758773
async def _restore_popped_session_items(
759-
session: Session, popped_items: Sequence[TResponseInputItem]
774+
session: Session,
775+
popped_items: Sequence[TResponseInputItem],
776+
*,
777+
wrapper: RunContextWrapper[Any] | None = None,
760778
) -> None:
761779
"""Best-effort restoration for items popped during a failed rewind attempt."""
762780
if not popped_items:
@@ -767,9 +785,7 @@ async def _restore_popped_session_items(
767785
return
768786

769787
try:
770-
result = add_items(list(reversed(popped_items)))
771-
if inspect.isawaitable(result):
772-
await result
788+
await _session_add_items(session, list(reversed(popped_items)), wrapper=wrapper)
773789
except Exception as exc:
774790
logger.warning("Failed to restore session items after a rewind mismatch: %s", exc)
775791

0 commit comments

Comments
 (0)