From cdf4f0317f851e9e699cea5897a553fade5a2cc1 Mon Sep 17 00:00:00 2001 From: jawwad-ali Date: Sat, 6 Jun 2026 15:32:03 +0500 Subject: [PATCH 1/4] feat: pass run context wrapper to session get_items/add_items Sessions can now opt into receiving the current run's RunContextWrapper by accepting a keyword-only wrapper parameter on get_items/add_items. The runner inspects the session's signature and only forwards the wrapper to implementations that accept it, so existing custom sessions keep working unchanged. Wrapper-style sessions (EncryptedSession, OpenAIResponsesCompactionSession) forward the wrapper to opted-in underlying sessions. This revives the approach from #2690, which addressed maintainer review feedback before going stale, rebased onto the current runner internals and extended to MongoDBSession and the guardrail-trip persistence path. Closes #2072 --- docs/sessions/index.md | 46 ++ examples/memory/file_session.py | 15 +- .../memory/advanced_sqlite_session.py | 10 +- .../extensions/memory/async_sqlite_session.py | 17 +- src/agents/extensions/memory/dapr_session.py | 15 +- .../extensions/memory/encrypt_session.py | 39 +- .../extensions/memory/mongodb_session.py | 15 +- src/agents/extensions/memory/redis_session.py | 15 +- .../extensions/memory/sqlalchemy_session.py | 15 +- .../memory/openai_conversations_session.py | 17 +- .../openai_responses_compaction_session.py | 29 +- src/agents/memory/session.py | 63 ++- src/agents/memory/sqlite_session.py | 17 +- src/agents/run.py | 24 +- .../run_internal/agent_runner_helpers.py | 2 + src/agents/run_internal/run_loop.py | 16 +- .../run_internal/session_persistence.py | 50 +- ...est_openai_responses_compaction_session.py | 43 +- tests/memory/test_session.py | 8 +- tests/memory/test_session_context_wrapper.py | 433 ++++++++++++++++++ tests/test_agent_as_tool.py | 14 +- tests/test_agent_runner.py | 117 ++++- tests/test_agent_runner_streamed.py | 16 +- tests/utils/simple_session.py | 24 +- 24 files changed, 984 insertions(+), 76 deletions(-) create mode 100644 tests/memory/test_session_context_wrapper.py diff --git a/docs/sessions/index.md b/docs/sessions/index.md index 8916f85fab..f36bf08d97 100644 --- a/docs/sessions/index.md +++ b/docs/sessions/index.md @@ -680,6 +680,52 @@ result = await Runner.run( ) ``` +### Accessing the run context in custom sessions + +Custom sessions can opt into receiving the current run's [`RunContextWrapper`][agents.run_context.RunContextWrapper] by accepting a keyword-only `wrapper` parameter on `get_items` and `add_items`. The runner passes the wrapper only when the session's signature accepts it, so existing session implementations keep working unchanged: + +```python +from typing import Any + +from agents.memory.session import SessionABC +from agents.items import TResponseInputItem +from agents.run_context import RunContextWrapper + +class ContextAwareSession(SessionABC): + """Session that scopes storage by data from the run context.""" + + def __init__(self, session_id: str): + self.session_id = session_id + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + # Use wrapper.context (e.g. a user ID) to scope retrieval. + user_id = wrapper.context.user_id if wrapper is not None else None + return await self._load_items(user_id, limit) + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + # Persist items together with data from the run context. + user_id = wrapper.context.user_id if wrapper is not None else None + await self._store_items(user_id, items) + + async def pop_item(self) -> TResponseInputItem | None: + ... + + async def clear_session(self) -> None: + ... +``` + +The `wrapper` parameter may be `None`, for example when session methods are called directly rather than through the runner, so implementations should always handle that case. Sessions that accept `**kwargs` on these methods also receive the wrapper through them. + ## Community session implementations The community has developed additional session implementations: diff --git a/examples/memory/file_session.py b/examples/memory/file_session.py index e62dbd167f..abd6d81717 100644 --- a/examples/memory/file_session.py +++ b/examples/memory/file_session.py @@ -15,6 +15,7 @@ from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class FileSession(Session): @@ -43,14 +44,24 @@ async def get_session_id(self) -> str: """Return the session id, creating one if needed.""" return await self._ensure_session_id() - async def get_items(self, limit: int | None = None) -> list[Any]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[Any]: session_id = await self._ensure_session_id() items = await self._read_items(session_id) if limit is not None and limit >= 0: return items[-limit:] return items - async def add_items(self, items: list[Any]) -> None: + async def add_items( + self, + items: list[Any], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: if not items: return session_id = await self._ensure_session_id() diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..ab560c43a6 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -15,6 +15,7 @@ from ...items import TResponseInputItem from ...memory import SQLiteSession from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class AdvancedSQLiteSession(SQLiteSession): @@ -121,7 +122,12 @@ def _init_structure_tables(self): conn.commit() - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add items to the session. Args: @@ -160,6 +166,8 @@ async def get_items( self, limit: int | None = None, branch_id: str | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> list[TResponseInputItem]: """Get items from current or specified branch. diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 27a23b1cbe..57c206b044 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,13 +5,14 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast +from typing import Any, cast import aiosqlite from ...items import TResponseInputItem from ...memory import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class AsyncSQLiteSession(SessionABC): @@ -106,7 +107,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]: conn = await self._get_connection() yield conn - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -156,7 +162,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 6ac68f6020..2e24c3f2ff 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -29,6 +29,7 @@ import time from typing import Any, Final, Literal +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -250,7 +251,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index 19ba7a5683..4ca7192114 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -37,8 +37,9 @@ from typing_extensions import TypedDict from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, session_method_accepts_wrapper from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class EncryptedEnvelope(TypedDict): @@ -180,12 +181,28 @@ def _unwrap_valid_items( valid_items.append(item) return valid_items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def _get_underlying_items( + self, limit: int | None, wrapper: RunContextWrapper[Any] | None + ) -> list[TResponseInputItem]: + # Forward the wrapper only when the underlying session opts in, so wrapping older + # custom sessions keeps working unchanged. + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.get_items + ): + return await self.underlying_session.get_items(limit, wrapper=wrapper) + return await self.underlying_session.get_items(limit) + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: effective_limit = resolve_session_limit(limit, self.session_settings) if effective_limit is not None and effective_limit > 0: window = effective_limit while True: - encrypted_items = await self.underlying_session.get_items(window) + encrypted_items = await self._get_underlying_items(window, wrapper) valid_items = self._unwrap_valid_items(encrypted_items) if len(valid_items) >= effective_limit: return valid_items[-effective_limit:] @@ -193,11 +210,23 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return valid_items window *= 2 - encrypted_items = await self.underlying_session.get_items(limit) + encrypted_items = await self._get_underlying_items(limit, wrapper) return self._unwrap_valid_items(encrypted_items) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.add_items + ): + await self.underlying_session.add_items( + cast(list[TResponseInputItem], wrapped), wrapper=wrapper + ) + return await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) async def pop_item(self) -> TResponseInputItem | None: diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 113acdc6af..5a45ca8599 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -37,6 +37,7 @@ from datetime import datetime, timezone from typing import Any +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -247,7 +248,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 11e2dd838b..c20dcf72a2 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -26,6 +26,7 @@ import time from typing import Any +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -145,7 +146,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -184,7 +190,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index fd2502e24b..dbfa042e5f 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -51,6 +51,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class SQLAlchemySession(SessionABC): @@ -274,7 +275,12 @@ async def _ensure_tables(self) -> None: finally: self._init_lock.release() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -326,7 +332,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..e05caf88de 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -1,10 +1,13 @@ from __future__ import annotations +from typing import Any + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -70,7 +73,12 @@ async def _get_session_id(self) -> str: async def _clear_session_id(self) -> None: self._session_id = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: session_id = await self._get_session_id() session_limit = resolve_session_limit(limit, self.session_settings) @@ -97,7 +105,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return all_items # type: ignore - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: session_id = await self._get_session_id() if not items: return diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index c112b706a1..48e2309156 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -8,12 +8,14 @@ from ..items import TResponseInputItem from ..models._openai_shared import get_default_openai_client +from ..run_context import RunContextWrapper from ..run_internal.items import normalize_input_items_for_api from .openai_conversations_session import OpenAIConversationsSession from .session import ( OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, SessionABC, + session_method_accepts_wrapper, ) if TYPE_CHECKING: @@ -233,7 +235,18 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + # Forward the wrapper only when the underlying session opts in, so wrapping older + # custom sessions keeps working unchanged. + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.get_items + ): + return await self.underlying_session.get_items(limit, wrapper=wrapper) return await self.underlying_session.get_items(limit) async def _get_all_underlying_session_items(self) -> list[TResponseInputItem]: @@ -331,8 +344,18 @@ def _get_deferred_compaction_response_id(self) -> str | None: def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None - async def add_items(self, items: list[TResponseInputItem]) -> None: - await self.underlying_session.add_items(items) + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.add_items + ): + await self.underlying_session.add_items(items, wrapper=wrapper) + else: + await self.underlying_session.add_items(items) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 1781b7ac9f..68ec558b93 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,12 +1,14 @@ from __future__ import annotations +import inspect from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Protocol, TypeGuard, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeGuard, runtime_checkable from typing_extensions import TypedDict if TYPE_CHECKING: from ..items import TResponseInputItem + from ..run_context import RunContextWrapper from .session_settings import SessionSettings @@ -21,23 +23,37 @@ class Session(Protocol): session_id: str session_settings: SessionSettings | None = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. Returns: List of input items representing the conversation history """ ... - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. """ ... @@ -68,12 +84,19 @@ class SessionABC(ABC): session_settings: SessionSettings | None = None @abstractmethod - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. Returns: List of input items representing the conversation history @@ -81,11 +104,18 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: ... @abstractmethod - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. """ ... @@ -148,3 +178,26 @@ def is_openai_responses_compaction_aware_session( except Exception: return False return callable(run_compaction) + + +def session_method_accepts_wrapper(method: Any) -> bool: + """Check if a session method accepts the keyword-only ``wrapper`` argument. + + The runner (and wrapper-style sessions such as ``EncryptedSession``) use this to pass + the current run context only to implementations that opt in, so older custom sessions + that predate the ``wrapper`` parameter keep working unchanged. Methods that accept + ``**kwargs`` are treated as opted in and receive the wrapper through them. + """ + try: + parameters = tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return False + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD + or ( + parameter.name == "wrapper" + and parameter.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + ) + for parameter in parameters + ) diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 3a69f9883a..695c3dfb7f 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -7,9 +7,10 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import ClassVar +from typing import Any, ClassVar from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -199,7 +200,12 @@ def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem (self.session_id,), ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -254,7 +260,12 @@ def _get_items_sync(): return await asyncio.to_thread(_get_items_sync) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/run.py b/src/agents/run.py index 014271a5ea..27e3c9b039 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -101,6 +101,7 @@ NextStepRunAgain, ) from .run_internal.session_persistence import ( + _session_get_items, persist_session_items_for_guardrail_trip, prepare_input_with_session, resumed_turn_items, @@ -510,6 +511,10 @@ async def run( raw_input = cast(str | list[TResponseInputItem], input) original_user_input = raw_input + context_wrapper = ensure_context_wrapper(context) + context = context_wrapper.context + set_agent_tool_state_scope(context_wrapper, None) + validate_session_conversation_settings( session, conversation_id=conversation_id, @@ -531,6 +536,7 @@ async def run( run_config.session_settings, include_history_in_prepared_input=False, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) original_input_for_state = raw_input session_input_items_for_persistence = [] @@ -543,6 +549,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, ) original_input_for_state = prepared_input @@ -579,7 +586,7 @@ async def run( session_input_items: list[TResponseInputItem] | None = None if session is not None: try: - session_input_items = await session.get_items() + session_input_items = await _session_get_items(session, wrapper=context_wrapper) except Exception: session_input_items = None server_conversation_tracker.hydrate_from_state( @@ -628,8 +635,6 @@ async def run( generated_items = [] session_items = [] model_responses = [] - context_wrapper = ensure_context_wrapper(context) - set_agent_tool_state_scope(context_wrapper, None) run_state = RunState( context=context_wrapper, original_input=original_input, @@ -754,6 +759,7 @@ def _finalize_result(result: RunResult) -> RunResult: [], run_state, store=store_setting, + wrapper=context_wrapper, ) session_input_items_for_persistence = [] except BaseException: @@ -796,6 +802,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -835,6 +842,7 @@ def _finalize_result(result: RunResult) -> RunResult: [], run_state, store=store_setting, + wrapper=context_wrapper, ) session_input_items_for_persistence = [] if run_state is not None and run_state._current_step is not None: @@ -893,6 +901,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) ) @@ -1005,6 +1014,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1139,6 +1149,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=None, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1187,6 +1198,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -1240,6 +1252,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -1347,6 +1360,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) run_state._current_turn_persisted_item_count += saved_count else: @@ -1357,6 +1371,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) # After the first resumed turn, treat subsequent turns as fresh @@ -1409,6 +1424,7 @@ def _finalize_result(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1428,6 +1444,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) append_model_response_if_new( model_responses, turn_result.model_response @@ -1489,6 +1506,7 @@ def _finalize_result(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) continue else: diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 84c67d6b8f..cad1201562 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -467,6 +467,7 @@ async def save_turn_items_if_needed( items: list[RunItem], response_id: str | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """Persist turn items when persistence is enabled and guardrails allow it.""" if not session_persistence_enabled: @@ -482,6 +483,7 @@ async def save_turn_items_if_needed( run_state, response_id=response_id, store=store, + wrapper=wrapper, ) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..45fe354be1 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -143,6 +143,7 @@ ToolRunShellCall, ) from .session_persistence import ( + _session_get_items, persist_session_items_for_guardrail_trip, prepare_input_with_session, resumed_turn_items, @@ -322,6 +323,7 @@ async def _save_resumed_stream_items( response_id=response_id, reasoning_item_id_policy=streamed_result._reasoning_item_id_policy, store=store, + wrapper=streamed_result.context_wrapper, ) if run_state is not None: run_state._current_turn_persisted_item_count = ( @@ -353,6 +355,7 @@ async def _save_stream_items( run_state, response_id=response_id, store=store, + wrapper=streamed_result.context_wrapper, ) if update_persisted_count and streamed_result._state is not None: streamed_result._current_turn_persisted_item_count = ( @@ -575,7 +578,7 @@ def _sync_conversation_tracking_from_tracker() -> None: session_items: list[TResponseInputItem] | None = None if session is not None: try: - session_items = await session.get_items() + session_items = await _session_get_items(session, wrapper=context_wrapper) except Exception: session_items = None server_conversation_tracker.hydrate_from_state( @@ -603,6 +606,7 @@ def _sync_conversation_tracking_from_tracker() -> None: run_config.session_settings, include_history_in_prepared_input=not server_manages_conversation, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) streamed_result.input = prepared_input streamed_result._original_input = copy_input_items(prepared_input) @@ -706,6 +710,7 @@ async def _save_stream_items_without_count( store=current_agent.model_settings.resolve( run_config.model_settings ).store, + wrapper=context_wrapper, ) ) raise InputGuardrailTripwireTriggered(result) @@ -978,6 +983,7 @@ async def _save_stream_items_without_count( store=current_agent.model_settings.resolve( run_config.model_settings ).store, + wrapper=context_wrapper, ) ) raise InputGuardrailTripwireTriggered(result) @@ -1420,7 +1426,13 @@ def _tool_search_fingerprint(raw_item: Any) -> str: ) ] if input_items_to_save: - await save_result_to_session(session, input_items_to_save, [], streamed_result._state) + await save_result_to_session( + session, + input_items_to_save, + [], + streamed_result._state, + wrapper=context_wrapper, + ) previous_response_id = ( server_conversation_tracker.previous_response_id diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index f483da13a3..dc7cf136b5 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -23,6 +23,8 @@ is_openai_responses_compaction_aware_session, ) from ..memory.openai_conversations_session import OpenAIConversationsSession +from ..memory.session import session_method_accepts_wrapper +from ..run_context import RunContextWrapper from ..run_state import RunState from .items import ( ReasoningItemIdPolicy, @@ -51,6 +53,39 @@ ] +async def _session_get_items( + session: Session, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + """Call ``session.get_items``, passing the run context only when the session opts in. + + Custom sessions that predate the keyword-only ``wrapper`` parameter keep working + unchanged because the wrapper is only forwarded when their signature accepts it. + """ + if wrapper is not None and session_method_accepts_wrapper(session.get_items): + if limit is not None: + return await session.get_items(limit=limit, wrapper=wrapper) + return await session.get_items(wrapper=wrapper) + if limit is not None: + return await session.get_items(limit=limit) + return await session.get_items() + + +async def _session_add_items( + session: Session, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> None: + """Call ``session.add_items``, passing the run context only when the session opts in.""" + if wrapper is not None and session_method_accepts_wrapper(session.add_items): + await session.add_items(items, wrapper=wrapper) + return + await session.add_items(items) + + async def prepare_input_with_session( input: str | list[TResponseInputItem], session: Session | None, @@ -59,6 +94,7 @@ async def prepare_input_with_session( *, include_history_in_prepared_input: bool = True, preserve_dropped_new_items: bool = False, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[str | list[TResponseInputItem], list[TResponseInputItem]]: """Prepare model input from session history plus the new turn input. @@ -83,9 +119,9 @@ async def prepare_input_with_session( resolved_settings = resolved_settings.resolve(session_settings) if resolved_settings.limit is not None: - history = await session.get_items(limit=resolved_settings.limit) + history = await _session_get_items(session, limit=resolved_settings.limit, wrapper=wrapper) else: - history = await session.get_items() + history = await _session_get_items(session, wrapper=wrapper) is_openai_conversation_session = isinstance(session, OpenAIConversationsSession) converted_history = [ strip_internal_input_item_metadata(ensure_input_item_format(item)) for item in history @@ -194,6 +230,7 @@ async def persist_session_items_for_guardrail_trip( original_user_input: str | list[TResponseInputItem] | None, run_state: RunState | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> list[TResponseInputItem] | None: """ Persist input items when a guardrail tripwire is triggered. @@ -208,7 +245,9 @@ async def persist_session_items_for_guardrail_trip( input_items_for_save: list[TResponseInputItem] = ( updated_session_input_items if updated_session_input_items is not None else [] ) - await save_result_to_session(session, input_items_for_save, [], run_state, store=store) + await save_result_to_session( + session, input_items_for_save, [], run_state, store=store, wrapper=wrapper + ) return updated_session_input_items @@ -253,6 +292,7 @@ async def save_result_to_session( response_id: str | None = None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """ Persist a turn to the session store, keeping track of what was already saved so retries @@ -346,7 +386,7 @@ async def save_result_to_session( run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count - await session.add_items(items_to_save) + await _session_add_items(session, items_to_save, wrapper=wrapper) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count @@ -397,6 +437,7 @@ async def save_resumed_turn_items( response_id: str | None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """Persist resumed turn items and return the updated persisted count.""" if session is None or not items: @@ -409,6 +450,7 @@ async def save_resumed_turn_items( response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, + wrapper=wrapper, ) return persisted_count + saved_count diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index fe893cf88a..3c2a85380c 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -22,6 +22,7 @@ is_openai_model_name, select_compaction_candidate_items, ) +from agents.run_context import RunContextWrapper from agents.run_internal.items import ( TOOL_CALL_SESSION_DESCRIPTION_KEY, TOOL_CALL_SESSION_TITLE_KEY, @@ -510,7 +511,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) @@ -566,12 +572,22 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None and self.session_settings is not None: limit = self.session_settings.limit return await super().get_items(limit) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) @@ -624,7 +640,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 await super().add_items(items) @@ -674,7 +695,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 await super().add_items(items) @@ -725,7 +751,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) diff --git a/tests/memory/test_session.py b/tests/memory/test_session.py index f9cc324d2e..995e0a25d0 100644 --- a/tests/memory/test_session.py +++ b/tests/memory/test_session.py @@ -4,10 +4,12 @@ import sqlite3 import tempfile from pathlib import Path +from typing import Any import pytest from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem +from agents.run_context import RunContextWrapper from tests.fake_model import FakeModel from tests.test_responses import get_text_message @@ -640,7 +642,11 @@ async def test_session_add_items_exception_propagates_in_streamed(): """ session = SQLiteSession("test_exception_session") - async def _failing_add_items(_items): + async def _failing_add_items( + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: raise RuntimeError("Simulated session.add_items failure") session.add_items = _failing_add_items # type: ignore[method-assign] diff --git a/tests/memory/test_session_context_wrapper.py b/tests/memory/test_session_context_wrapper.py new file mode 100644 index 0000000000..8a56013c09 --- /dev/null +++ b/tests/memory/test_session_context_wrapper.py @@ -0,0 +1,433 @@ +"""Tests for passing the run context wrapper to Session methods (issue #2072).""" + +from __future__ import annotations + +import asyncio +import inspect +from dataclasses import dataclass +from typing import Any + +import pytest + +from agents import ( + Agent, + GuardrailFunctionOutput, + InputGuardrailTripwireTriggered, + RunConfig, + Runner, + SQLiteSession, + TResponseInputItem, + input_guardrail, +) +from agents.memory.session import SessionABC, session_method_accepts_wrapper +from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message + + +@dataclass +class UserInfo: + """Sample user-defined run context object.""" + + user_id: str = "user-123" + + +class ContextAwareSession(SessionABC): + """Session that opts into the wrapper parameter and records what it receives.""" + + def __init__(self, session_id: str = "context-aware"): + self.session_id = session_id + self._items: list[TResponseInputItem] = [] + self.get_items_wrappers: list[RunContextWrapper[Any] | None] = [] + self.get_items_limits: list[int | None] = [] + self.add_items_wrappers: list[RunContextWrapper[Any] | None] = [] + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + self.get_items_wrappers.append(wrapper) + self.get_items_limits.append(limit) + if limit is not None: + return self._items[-limit:] + return list(self._items) + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + self.add_items_wrappers.append(wrapper) + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return self._items.pop() if self._items else None + + async def clear_session(self) -> None: + self._items.clear() + + +class LegacySession(SessionABC): + """Session with pre-wrapper signatures, as third-party implementations may still have.""" + + def __init__(self, session_id: str = "legacy"): + self.session_id = session_id + self._items: list[TResponseInputItem] = [] + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: # type: ignore[override] + if limit is not None: + return self._items[-limit:] + return list(self._items) + + async def add_items(self, items: list[TResponseInputItem]) -> None: # type: ignore[override] + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return self._items.pop() if self._items else None + + async def clear_session(self) -> None: + self._items.clear() + + +class VarKwargsSession(LegacySession): + """Session that accepts the wrapper through ``**kwargs`` rather than a named parameter.""" + + def __init__(self, session_id: str = "var-kwargs"): + super().__init__(session_id) + self.received_kwargs: list[dict[str, Any]] = [] + + async def get_items(self, limit: int | None = None, **kwargs: Any) -> list[TResponseInputItem]: + self.received_kwargs.append(kwargs) + return await super().get_items(limit) + + async def add_items(self, items: list[TResponseInputItem], **kwargs: Any) -> None: + self.received_kwargs.append(kwargs) + await super().add_items(items) + + +def _run_sync_wrapper(agent, input_data, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + async for _ in result.stream_events(): + pass + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_runner_passes_wrapper_to_context_aware_session(runner_method): + """Sessions that opt in receive the run context wrapper from every runner entrypoint.""" + session = ContextAwareSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + result = await run_agent_async( + runner_method, agent, "Hi there", session=session, context=context + ) + assert result.final_output == "Hello" + + assert len(session.get_items_wrappers) > 0 + assert len(session.add_items_wrappers) > 0 + for wrapper in session.get_items_wrappers + session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_runner_keeps_legacy_session_working(runner_method): + """Sessions without the wrapper parameter keep working unchanged.""" + session = LegacySession() + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("San Francisco")]) + result1 = await run_agent_async( + runner_method, agent, "What city is the Golden Gate Bridge in?", session=session + ) + assert result1.final_output == "San Francisco" + + model.set_next_output([get_text_message("California")]) + result2 = await run_agent_async(runner_method, agent, "What state is it in?", session=session) + assert result2.final_output == "California" + + # The second turn must include the persisted history from the first turn. + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 + + +@pytest.mark.asyncio +async def test_runner_passes_wrapper_to_var_kwargs_session(): + """Sessions accepting **kwargs receive the wrapper through them.""" + session = VarKwargsSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + await Runner.run(agent, "Hi there", session=session, context=context) + + wrappers = [kwargs["wrapper"] for kwargs in session.received_kwargs if "wrapper" in kwargs] + assert len(wrappers) > 0 + for wrapper in wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_runner_without_explicit_context_passes_wrapper(): + """The wrapper is passed even when the caller does not provide a context object.""" + session = ContextAwareSession() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + await Runner.run(agent, "Hi there", session=session) + + assert len(session.add_items_wrappers) > 0 + for wrapper in session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + + +def test_session_method_accepts_wrapper_helper(): + """The capability check recognizes opt-in signatures and rejects legacy ones.""" + context_aware = ContextAwareSession() + legacy = LegacySession() + var_kwargs = VarKwargsSession() + + assert session_method_accepts_wrapper(context_aware.get_items) is True + assert session_method_accepts_wrapper(context_aware.add_items) is True + assert session_method_accepts_wrapper(var_kwargs.get_items) is True + assert session_method_accepts_wrapper(var_kwargs.add_items) is True + assert session_method_accepts_wrapper(legacy.get_items) is False + assert session_method_accepts_wrapper(legacy.add_items) is False + # Callables without an introspectable signature must not be treated as opted in. + assert session_method_accepts_wrapper(max) is False + + +@pytest.mark.asyncio +async def test_sqlite_session_accepts_and_ignores_wrapper(): + """Built-in sessions accept the wrapper directly and behave the same with or without it.""" + session = SQLiteSession("direct-call-test") + try: + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved_with = await session.get_items(wrapper=wrapper) + retrieved_without = await session.get_items() + + assert retrieved_with == retrieved_without + assert len(retrieved_with) == 1 + finally: + session.close() + + +_BUILTIN_SESSION_SPECS = [ + ("agents.memory.sqlite_session", "SQLiteSession", None), + ("agents.memory.openai_conversations_session", "OpenAIConversationsSession", None), + ( + "agents.memory.openai_responses_compaction_session", + "OpenAIResponsesCompactionSession", + None, + ), + ("agents.extensions.memory.async_sqlite_session", "AsyncSQLiteSession", None), + ("agents.extensions.memory.advanced_sqlite_session", "AdvancedSQLiteSession", None), + ("agents.extensions.memory.encrypt_session", "EncryptedSession", "cryptography"), + ("agents.extensions.memory.redis_session", "RedisSession", "redis"), + ("agents.extensions.memory.sqlalchemy_session", "SQLAlchemySession", "sqlalchemy"), + ("agents.extensions.memory.dapr_session", "DaprSession", "dapr"), + ("agents.extensions.memory.mongodb_session", "MongoDBSession", "pymongo"), +] + + +@pytest.mark.parametrize( + "module_name,class_name,required_package", + _BUILTIN_SESSION_SPECS, + ids=[spec[1] for spec in _BUILTIN_SESSION_SPECS], +) +def test_builtin_sessions_expose_keyword_only_wrapper(module_name, class_name, required_package): + """Every built-in session implementation exposes the keyword-only wrapper parameter.""" + if required_package is not None: + pytest.importorskip(required_package) + module = pytest.importorskip(module_name) + session_cls = getattr(module, class_name) + + for method_name in ("get_items", "add_items"): + signature = inspect.signature(getattr(session_cls, method_name)) + parameter = signature.parameters.get("wrapper") + assert parameter is not None, f"{class_name}.{method_name} is missing wrapper" + assert parameter.kind is inspect.Parameter.KEYWORD_ONLY + assert parameter.default is None + + +@pytest.mark.asyncio +async def test_guardrail_trip_persists_input_with_wrapper(): + """The guardrail-trip persistence path forwards the wrapper to the session.""" + session = ContextAwareSession() + context = UserInfo() + + @input_guardrail + def always_trip(ctx, agent, input) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True) + + model = FakeModel() + agent = Agent(name="test", model=model, input_guardrails=[always_trip]) + model.set_next_output([get_text_message("never returned")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "Hi there", session=session, context=context) + + assert len(session.add_items_wrappers) > 0 + for wrapper in session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_session_input_callback_path_passes_wrapper(): + """The history-merge callback path still forwards the wrapper on get_items.""" + session = ContextAwareSession() + context = UserInfo() + + def keep_everything( + history: list[TResponseInputItem], new_items: list[TResponseInputItem] + ) -> list[TResponseInputItem]: + return history + new_items + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("first")]) + await Runner.run(agent, "Turn one", session=session, context=context) + + model.set_next_output([get_text_message("second")]) + await Runner.run( + agent, + "Turn two", + session=session, + context=context, + run_config=RunConfig(session_input_callback=keep_everything), + ) + + assert len(session.get_items_wrappers) >= 2 + for wrapper in session.get_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_session_settings_limit_path_passes_wrapper(): + """The limited-history read passes both the limit and the wrapper to the session.""" + session = ContextAwareSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("first")]) + await Runner.run(agent, "Turn one", session=session, context=context) + + model.set_next_output([get_text_message("second")]) + await Runner.run( + agent, + "Turn two", + session=session, + context=context, + run_config=RunConfig(session_settings=SessionSettings(limit=1)), + ) + + # The second run's history read must use the limit and still carry the wrapper. + assert session.get_items_limits[-1] == 1 + last_wrapper = session.get_items_wrappers[-1] + assert isinstance(last_wrapper, RunContextWrapper) + assert last_wrapper.context is context + + +@pytest.mark.asyncio +async def test_encrypted_session_forwards_wrapper_to_underlying_session(): + """EncryptedSession forwards the wrapper to underlying sessions that opt in.""" + pytest.importorskip("cryptography") + from agents.extensions.memory.encrypt_session import EncryptedSession + + underlying = ContextAwareSession() + session = EncryptedSession( + session_id="enc-forward", + underlying_session=underlying, + encryption_key="test-key", + ) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + await session.get_items(wrapper=wrapper) + + assert underlying.add_items_wrappers == [wrapper] + assert len(underlying.get_items_wrappers) > 0 + assert all(received is wrapper for received in underlying.get_items_wrappers) + + +@pytest.mark.asyncio +async def test_encrypted_session_does_not_break_legacy_underlying_session(): + """EncryptedSession never passes the wrapper to underlying sessions that predate it.""" + pytest.importorskip("cryptography") + from agents.extensions.memory.encrypt_session import EncryptedSession + + underlying = LegacySession() + session = EncryptedSession( + session_id="enc-legacy", + underlying_session=underlying, + encryption_key="test-key", + ) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved = await session.get_items(wrapper=wrapper) + + assert len(retrieved) == 1 + + +@pytest.mark.asyncio +async def test_compaction_session_forwards_wrapper_to_underlying_session(): + """OpenAIResponsesCompactionSession forwards the wrapper to opted-in underlying sessions.""" + from agents.memory.openai_responses_compaction_session import ( + OpenAIResponsesCompactionSession, + ) + + underlying = ContextAwareSession() + session = OpenAIResponsesCompactionSession("compaction-forward", underlying) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved = await session.get_items(wrapper=wrapper) + + assert underlying.add_items_wrappers == [wrapper] + assert underlying.get_items_wrappers == [wrapper] + assert len(retrieved) == 1 diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index c5cc123034..737a56c304 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -328,10 +328,20 @@ class DummySession(Session): session_id = "sess_123" session_settings = SessionSettings() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: return None async def pop_item(self) -> TResponseInputItem | None: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index eb22c70f14..874cddfd1d 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -7,7 +7,7 @@ from collections.abc import Callable from pathlib import Path from typing import Any, cast -from unittest.mock import patch +from unittest.mock import ANY, call, patch import httpx import pytest @@ -2123,12 +2123,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2227,12 +2237,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2280,12 +2300,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2365,7 +2395,12 @@ def __init__(self) -> None: super().__init__() self.get_items_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: self.get_items_calls += 1 if self.get_items_calls == 1: raise RuntimeError("temporary failure") @@ -2720,10 +2755,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2807,10 +2852,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2862,10 +2917,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2908,10 +2973,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -4529,15 +4604,17 @@ async def echo_tool(text: str) -> str: }, ] + # The runner passes the run context wrapper to sessions that accept it, and the + # patched mock accepts any kwargs, so expect the wrapper keyword argument too. expected_calls = [ # First call is the initial input - (([expected_items[0]],),), + call([expected_items[0]], wrapper=ANY), # Second call is the first tool call and its result - (([expected_items[1], expected_items[2]],),), + call([expected_items[1], expected_items[2]], wrapper=ANY), # Third call is the second tool call and its result - (([expected_items[3], expected_items[4]],),), + call([expected_items[3], expected_items[4]], wrapper=ANY), # Fourth call is the final output - (([expected_items[5]],),), + call([expected_items[5]], wrapper=ANY), ] assert mock_add_items.call_args_list == expected_calls assert result.final_output == "Summary: Echoed foo and bar" diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 8ee3a55db4..b94b33005e 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1263,7 +1263,12 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: for item in items: if isinstance(item, dict): assert "id" not in item, "IDs should be stripped before saving" @@ -1272,7 +1277,12 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ) self.saved.append(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -1916,6 +1926,7 @@ async def save_wrapper( response_id: str | None, reasoning_item_id_policy: str | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: observed_counts.append(persisted_count) result = await real_save_resumed( @@ -1925,6 +1936,7 @@ async def save_wrapper( response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, + wrapper=wrapper, ) return int(result) diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..a9d3721259 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from agents.items import TResponseInputItem from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class SimpleListSession(Session): @@ -24,14 +25,24 @@ def __init__( # Mirror saved_items used by some tests for inspection. self.saved_items: list[TResponseInputItem] = self._items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self._items) if limit <= 0: return [] return self._items[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -70,7 +81,12 @@ def __init__( super().__init__(session_id=session_id, history=history) self._ignore_ids_for_matching = True - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: sanitized: list[TResponseInputItem] = [] for item in items: if isinstance(item, dict): From 1f93475667243d6fd1e20c5ee1431817dc541cf2 Mon Sep 17 00:00:00 2001 From: jawwad-ali Date: Wed, 10 Jun 2026 19:48:46 +0500 Subject: [PATCH 2/4] fix: thread run context wrapper through session retry-rewind and compaction paths Addresses review feedback: a wrapper-aware session that scopes storage by wrapper.context was getting the wrapper on normal reads/writes but not on the conversation-retry rewind path or the compaction decorator's internal history reads/replacements, so those operated on the unscoped/default store. - Thread the wrapper through rewind_session_items, the tail-suffix rewind, cleanup verification, and popped-item restoration. - Thread the wrapper through OpenAIResponsesCompactionSession's run_compaction, candidate loading, and replace/restore helpers, and add it to the run_compaction protocol method. clear_session/pop_item keep their existing signatures, matching the get_items/add_items-only scope of this change. --- .../openai_responses_compaction_session.py | 86 ++++++++++++++----- src/agents/memory/session.py | 15 +++- src/agents/run_internal/run_loop.py | 8 +- .../run_internal/session_persistence.py | 42 ++++++--- tests/memory/test_session_context_wrapper.py | 47 ++++++++++ 5 files changed, 160 insertions(+), 38 deletions(-) diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index 48e2309156..bc360c7f39 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -159,7 +159,12 @@ def _resolve_compaction_mode_for_response( return "input" return _resolve_compaction_mode(mode, response_id=response_id, store=store) - async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: + async def run_compaction( + self, + args: OpenAIResponsesCompactionArgs | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Run compaction using responses.compact API.""" if args and args.get("response_id"): self._response_id = args["response_id"] @@ -184,7 +189,9 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None "when using previous_response_id compaction." ) - compaction_candidate_items, session_items = await self._ensure_compaction_candidates() + compaction_candidate_items, session_items = await self._ensure_compaction_candidates( + wrapper=wrapper + ) force = args.get("force", False) if args else False should_compact = force or self.should_trigger_compaction( @@ -220,10 +227,11 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None _normalize_compaction_output_items(compacted.output or []) ) - previous_items = await self._get_all_underlying_session_items() + previous_items = await self._get_all_underlying_session_items(wrapper=wrapper) await self._replace_underlying_session_items( output_items=output_items, previous_items=previous_items, + wrapper=wrapper, ) self._compaction_candidate_items = select_compaction_candidate_items(output_items) @@ -235,7 +243,7 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items( + async def _underlying_get_items( self, limit: int | None = None, *, @@ -249,37 +257,65 @@ async def get_items( return await self.underlying_session.get_items(limit, wrapper=wrapper) return await self.underlying_session.get_items(limit) - async def _get_all_underlying_session_items(self) -> list[TResponseInputItem]: - return await self.underlying_session.get_items(limit=_ALL_SESSION_ITEMS_LIMIT) + async def _underlying_add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.add_items + ): + await self.underlying_session.add_items(items, wrapper=wrapper) + return + await self.underlying_session.add_items(items) + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + return await self._underlying_get_items(limit, wrapper=wrapper) + + async def _get_all_underlying_session_items( + self, *, wrapper: RunContextWrapper[Any] | None = None + ) -> list[TResponseInputItem]: + return await self._underlying_get_items(_ALL_SESSION_ITEMS_LIMIT, wrapper=wrapper) async def _replace_underlying_session_items( self, *, output_items: list[TResponseInputItem], previous_items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: await self.underlying_session.clear_session() except Exception as clear_error: await self._restore_underlying_session_items_after_failed_clear( - previous_items, clear_error + previous_items, clear_error, wrapper=wrapper ) raise try: if output_items: - await self.underlying_session.add_items(output_items) + await self._underlying_add_items(output_items, wrapper=wrapper) except Exception as replacement_error: - await self._restore_underlying_session_items(previous_items, replacement_error) + await self._restore_underlying_session_items( + previous_items, replacement_error, wrapper=wrapper + ) raise async def _restore_underlying_session_items_after_failed_clear( self, previous_items: list[TResponseInputItem], clear_error: Exception, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: - current_items = await self._get_all_underlying_session_items() + current_items = await self._get_all_underlying_session_items(wrapper=wrapper) except Exception: logger.warning( "Failed to inspect session history after compaction replacement clear failed.", @@ -291,7 +327,7 @@ async def _restore_underlying_session_items_after_failed_clear( return await self._restore_underlying_session_items( - previous_items, clear_error, clear_existing_items=False + previous_items, clear_error, clear_existing_items=False, wrapper=wrapper ) async def _restore_underlying_session_items( @@ -300,12 +336,13 @@ async def _restore_underlying_session_items( replacement_error: Exception, *, clear_existing_items: bool = True, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: if clear_existing_items: await self.underlying_session.clear_session() if previous_items: - await self.underlying_session.add_items(list(previous_items)) + await self._underlying_add_items(list(previous_items), wrapper=wrapper) except Exception: logger.warning( "Failed to restore session history after compaction replacement failed.", @@ -318,10 +355,18 @@ async def _restore_underlying_session_items( replacement_error, ) - async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: + async def _defer_compaction( + self, + response_id: str, + store: bool | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: if self._deferred_response_id is not None: return - compaction_candidate_items, session_items = await self._ensure_compaction_candidates() + compaction_candidate_items, session_items = await self._ensure_compaction_candidates( + wrapper=wrapper + ) resolved_mode = self._resolve_compaction_mode_for_response( response_id=response_id, store=store, @@ -350,12 +395,7 @@ async def add_items( *, wrapper: RunContextWrapper[Any] | None = None, ) -> None: - if wrapper is not None and session_method_accepts_wrapper( - self.underlying_session.add_items - ): - await self.underlying_session.add_items(items, wrapper=wrapper) - else: - await self.underlying_session.add_items(items) + await self._underlying_add_items(items, wrapper=wrapper) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) @@ -379,12 +419,16 @@ async def clear_session(self) -> None: async def _ensure_compaction_candidates( self, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[list[TResponseInputItem], list[TResponseInputItem]]: """Lazy-load and cache compaction candidates.""" if self._compaction_candidate_items is not None and self._session_items is not None: return (self._compaction_candidate_items[:], self._session_items[:]) - history = _normalize_compaction_session_items(await self.underlying_session.get_items()) + history = _normalize_compaction_session_items( + await self._underlying_get_items(wrapper=wrapper) + ) candidates = select_compaction_candidate_items(history) self._compaction_candidate_items = candidates self._session_items = history diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 68ec558b93..888a7e6ae7 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -162,8 +162,19 @@ class OpenAIResponsesCompactionArgs(TypedDict, total=False): class OpenAIResponsesCompactionAwareSession(Session, Protocol): """Protocol for session implementations that support responses compaction.""" - async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: - """Run the compaction process for the session.""" + async def run_compaction( + self, + args: OpenAIResponsesCompactionArgs | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + """Run the compaction process for the session. + + Args: + args: Optional compaction arguments. + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. + """ ... diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45fe354be1..77626b144b 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1463,7 +1463,9 @@ def _tool_search_fingerprint(raw_item: Any) -> str: async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items(session, items_to_rewind, server_conversation_tracker) + await rewind_session_items( + session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper + ) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) @@ -1887,7 +1889,9 @@ async def get_new_response( async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items(session, items_to_rewind, server_conversation_tracker) + await rewind_session_items( + session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper + ) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index dc7cf136b5..7db002b1a7 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -398,7 +398,10 @@ async def save_result_to_session( if has_local_tool_outputs: defer_compaction = getattr(session, "_defer_compaction", None) if callable(defer_compaction): - result = defer_compaction(response_id, store=store) + if session_method_accepts_wrapper(defer_compaction): + result = defer_compaction(response_id, store=store, wrapper=wrapper) + else: + result = defer_compaction(response_id, store=store) if inspect.isawaitable(result): await result logger.debug( @@ -424,7 +427,10 @@ async def save_result_to_session( } if store is not None: compaction_args["store"] = store - await session.run_compaction(compaction_args) + if session_method_accepts_wrapper(session.run_compaction): + await session.run_compaction(compaction_args, wrapper=wrapper) + else: + await session.run_compaction(compaction_args) return saved_run_items_count @@ -459,6 +465,8 @@ async def rewind_session_items( session: Session | None, items: Sequence[TResponseInputItem], server_tracker: OpenAIServerConversationTracker | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """ Best-effort helper to roll back items recently persisted to a session when a conversation @@ -499,6 +507,7 @@ async def rewind_session_items( "Skipping session rewind because the current tail does not match the retry-owned suffix" ), pop_failure_warning="Failed to rewind session item: %s", + wrapper=wrapper, ) if not rewound: return @@ -507,13 +516,14 @@ async def rewind_session_items( session, snapshot_serializations, ignore_ids_for_matching=ignore_ids_for_matching, + wrapper=wrapper, ) if session is None or server_tracker is None: return try: - latest_items = await session.get_items(limit=1) + latest_items = await _session_get_items(session, limit=1, wrapper=wrapper) except Exception as exc: logger.debug("Failed to peek session items while rewinding: %s", exc) return @@ -526,7 +536,7 @@ async def rewind_session_items( return try: - session_items = await session.get_items() + session_items = await _session_get_items(session, wrapper=wrapper) except Exception as exc: logger.debug("Failed to inspect session tail while stripping stray items: %s", exc) return @@ -554,6 +564,7 @@ async def rewind_session_items( "retry-owned conversation items" ), pop_failure_warning="Failed to strip stray session item: %s", + wrapper=wrapper, ) @@ -563,6 +574,7 @@ async def wait_for_session_cleanup( *, max_attempts: int = 5, ignore_ids_for_matching: bool = False, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """ Confirm that rewound items are no longer present in the session tail so the store stays @@ -575,7 +587,7 @@ async def wait_for_session_cleanup( for attempt in range(max_attempts): try: - tail_items = await session.get_items(limit=window) + tail_items = await _session_get_items(session, limit=window, wrapper=wrapper) except Exception as exc: logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc) await asyncio.sleep(0.1 * (attempt + 1)) @@ -700,13 +712,16 @@ async def _rewind_session_tail_suffix( ignore_ids_for_matching: bool, mismatch_warning: str, pop_failure_warning: str, + wrapper: RunContextWrapper[Any] | None = None, ) -> bool: """Remove an exact serialized suffix from the session tail, aborting when the tail diverges.""" if not expected_serializations: return True try: - tail_items = await session.get_items(limit=len(expected_serializations)) + tail_items = await _session_get_items( + session, limit=len(expected_serializations), wrapper=wrapper + ) except Exception as exc: logger.warning(pop_failure_warning, exc) return False @@ -734,12 +749,12 @@ async def _rewind_session_tail_suffix( if inspect.isawaitable(result): result = await result except Exception as exc: - await _restore_popped_session_items(session, popped_items) + await _restore_popped_session_items(session, popped_items, wrapper=wrapper) logger.warning(pop_failure_warning, exc) return False if result is None: - await _restore_popped_session_items(session, popped_items) + await _restore_popped_session_items(session, popped_items, wrapper=wrapper) logger.warning(mismatch_warning) return False @@ -748,7 +763,7 @@ async def _rewind_session_tail_suffix( result, ignore_ids_for_matching=ignore_ids_for_matching ) if popped_serialized != expected: - await _restore_popped_session_items(session, popped_items) + await _restore_popped_session_items(session, popped_items, wrapper=wrapper) logger.warning(mismatch_warning) return False @@ -756,7 +771,10 @@ async def _rewind_session_tail_suffix( async def _restore_popped_session_items( - session: Session, popped_items: Sequence[TResponseInputItem] + session: Session, + popped_items: Sequence[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """Best-effort restoration for items popped during a failed rewind attempt.""" if not popped_items: @@ -767,9 +785,7 @@ async def _restore_popped_session_items( return try: - result = add_items(list(reversed(popped_items))) - if inspect.isawaitable(result): - await result + await _session_add_items(session, list(reversed(popped_items)), wrapper=wrapper) except Exception as exc: logger.warning("Failed to restore session items after a rewind mismatch: %s", exc) diff --git a/tests/memory/test_session_context_wrapper.py b/tests/memory/test_session_context_wrapper.py index 8a56013c09..a6c2e15f5a 100644 --- a/tests/memory/test_session_context_wrapper.py +++ b/tests/memory/test_session_context_wrapper.py @@ -431,3 +431,50 @@ async def test_compaction_session_forwards_wrapper_to_underlying_session(): assert underlying.add_items_wrappers == [wrapper] assert underlying.get_items_wrappers == [wrapper] assert len(retrieved) == 1 + + +@pytest.mark.asyncio +async def test_compaction_run_compaction_forwards_wrapper_to_underlying_reads(): + """run_compaction forwards the wrapper to the underlying session's history reads.""" + from agents.memory.openai_responses_compaction_session import ( + OpenAIResponsesCompactionSession, + ) + + underlying = ContextAwareSession() + await underlying.add_items([{"role": "user", "content": "hello"}]) + underlying.get_items_wrappers.clear() + + # Decline actual compaction so no OpenAI client call is made; only the candidate + # read path (which forwards the wrapper) runs. + session = OpenAIResponsesCompactionSession( + "compaction-run", + underlying, + should_trigger_compaction=lambda _info: False, + ) + wrapper = RunContextWrapper(context=UserInfo()) + + await session.run_compaction({"response_id": "resp_1"}, wrapper=wrapper) + + assert len(underlying.get_items_wrappers) > 0 + for received in underlying.get_items_wrappers: + assert received is wrapper + + +@pytest.mark.asyncio +async def test_rewind_session_items_forwards_wrapper(): + """The retry-rewind helper forwards the wrapper to the session it reads and pops.""" + from agents.run_internal.session_persistence import rewind_session_items + + session = ContextAwareSession() + item: TResponseInputItem = {"role": "user", "content": "to be rewound"} + await session.add_items([item]) + session.get_items_wrappers.clear() + + wrapper = RunContextWrapper(context=UserInfo()) + await rewind_session_items(session, [item], wrapper=wrapper) + + # The rewind read the tail with the wrapper and popped the matching item back off. + assert len(session.get_items_wrappers) > 0 + for received in session.get_items_wrappers: + assert received is wrapper + assert await session.get_items() == [] From 327ab24cdc30c4294512ecda43b590f81fbc80b3 Mon Sep 17 00:00:00 2001 From: jawwad-ali Date: Wed, 10 Jun 2026 19:58:21 +0500 Subject: [PATCH 3/4] fix: do not forward run context wrapper through OpenAIResponsesCompactionSession The compaction decorator rewrites history by clearing and replacing the underlying store during compaction. clear_session is intentionally not part of the get_items/add_items wrapper contract, so scoping only the reads and adds (but not the clear) would let compaction read/write a context-scoped store while clearing the default scope. Make the decorator consistent instead: it accepts the keyword-only wrapper to stay protocol-compatible but does not forward it to the underlying session, so all of its history operations use the same (default) scope. Document that wrapping a context-scoped session in the compaction decorator is unsupported for run-context scoping, while transparent proxies like EncryptedSession still forward the wrapper. --- docs/sessions/index.md | 2 + .../openai_responses_compaction_session.py | 93 +++++-------------- src/agents/memory/session.py | 15 +-- .../run_internal/session_persistence.py | 10 +- tests/memory/test_session_context_wrapper.py | 42 +++------ 5 files changed, 38 insertions(+), 124 deletions(-) diff --git a/docs/sessions/index.md b/docs/sessions/index.md index f36bf08d97..3f572d17a6 100644 --- a/docs/sessions/index.md +++ b/docs/sessions/index.md @@ -726,6 +726,8 @@ class ContextAwareSession(SessionABC): The `wrapper` parameter may be `None`, for example when session methods are called directly rather than through the runner, so implementations should always handle that case. Sessions that accept `**kwargs` on these methods also receive the wrapper through them. +Wrapping a context-aware session in `OpenAIResponsesCompactionSession` is not supported for run-context scoping: that decorator rewrites history by clearing and replacing the underlying store during compaction, which cannot be scoped consistently through the `get_items`/`add_items` wrapper, so it does not forward the run context to the underlying session. Transparent wrappers such as `EncryptedSession` do forward the wrapper to underlying sessions that opt in. + ## Community session implementations The community has developed additional session implementations: diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index bc360c7f39..abaccdbf5d 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -15,7 +15,6 @@ OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, SessionABC, - session_method_accepts_wrapper, ) if TYPE_CHECKING: @@ -159,12 +158,7 @@ def _resolve_compaction_mode_for_response( return "input" return _resolve_compaction_mode(mode, response_id=response_id, store=store) - async def run_compaction( - self, - args: OpenAIResponsesCompactionArgs | None = None, - *, - wrapper: RunContextWrapper[Any] | None = None, - ) -> None: + async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: """Run compaction using responses.compact API.""" if args and args.get("response_id"): self._response_id = args["response_id"] @@ -189,9 +183,7 @@ async def run_compaction( "when using previous_response_id compaction." ) - compaction_candidate_items, session_items = await self._ensure_compaction_candidates( - wrapper=wrapper - ) + compaction_candidate_items, session_items = await self._ensure_compaction_candidates() force = args.get("force", False) if args else False should_compact = force or self.should_trigger_compaction( @@ -227,11 +219,10 @@ async def run_compaction( _normalize_compaction_output_items(compacted.output or []) ) - previous_items = await self._get_all_underlying_session_items(wrapper=wrapper) + previous_items = await self._get_all_underlying_session_items() await self._replace_underlying_session_items( output_items=output_items, previous_items=previous_items, - wrapper=wrapper, ) self._compaction_candidate_items = select_compaction_candidate_items(output_items) @@ -243,79 +234,48 @@ async def run_compaction( f"candidates={len(self._compaction_candidate_items)})" ) - async def _underlying_get_items( - self, - limit: int | None = None, - *, - wrapper: RunContextWrapper[Any] | None = None, - ) -> list[TResponseInputItem]: - # Forward the wrapper only when the underlying session opts in, so wrapping older - # custom sessions keeps working unchanged. - if wrapper is not None and session_method_accepts_wrapper( - self.underlying_session.get_items - ): - return await self.underlying_session.get_items(limit, wrapper=wrapper) - return await self.underlying_session.get_items(limit) - - async def _underlying_add_items( - self, - items: list[TResponseInputItem], - *, - wrapper: RunContextWrapper[Any] | None = None, - ) -> None: - if wrapper is not None and session_method_accepts_wrapper( - self.underlying_session.add_items - ): - await self.underlying_session.add_items(items, wrapper=wrapper) - return - await self.underlying_session.add_items(items) - async def get_items( self, limit: int | None = None, *, wrapper: RunContextWrapper[Any] | None = None, ) -> list[TResponseInputItem]: - return await self._underlying_get_items(limit, wrapper=wrapper) + # This decorator rewrites history via clear_session + add_items during compaction, + # which cannot be scoped consistently through a wrapper under the get_items/add_items + # contract, so it does not forward the run context to the underlying session. + return await self.underlying_session.get_items(limit) - async def _get_all_underlying_session_items( - self, *, wrapper: RunContextWrapper[Any] | None = None - ) -> list[TResponseInputItem]: - return await self._underlying_get_items(_ALL_SESSION_ITEMS_LIMIT, wrapper=wrapper) + async def _get_all_underlying_session_items(self) -> list[TResponseInputItem]: + return await self.underlying_session.get_items(limit=_ALL_SESSION_ITEMS_LIMIT) async def _replace_underlying_session_items( self, *, output_items: list[TResponseInputItem], previous_items: list[TResponseInputItem], - wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: await self.underlying_session.clear_session() except Exception as clear_error: await self._restore_underlying_session_items_after_failed_clear( - previous_items, clear_error, wrapper=wrapper + previous_items, clear_error ) raise try: if output_items: - await self._underlying_add_items(output_items, wrapper=wrapper) + await self.underlying_session.add_items(output_items) except Exception as replacement_error: - await self._restore_underlying_session_items( - previous_items, replacement_error, wrapper=wrapper - ) + await self._restore_underlying_session_items(previous_items, replacement_error) raise async def _restore_underlying_session_items_after_failed_clear( self, previous_items: list[TResponseInputItem], clear_error: Exception, - *, - wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: - current_items = await self._get_all_underlying_session_items(wrapper=wrapper) + current_items = await self._get_all_underlying_session_items() except Exception: logger.warning( "Failed to inspect session history after compaction replacement clear failed.", @@ -327,7 +287,7 @@ async def _restore_underlying_session_items_after_failed_clear( return await self._restore_underlying_session_items( - previous_items, clear_error, clear_existing_items=False, wrapper=wrapper + previous_items, clear_error, clear_existing_items=False ) async def _restore_underlying_session_items( @@ -336,13 +296,12 @@ async def _restore_underlying_session_items( replacement_error: Exception, *, clear_existing_items: bool = True, - wrapper: RunContextWrapper[Any] | None = None, ) -> None: try: if clear_existing_items: await self.underlying_session.clear_session() if previous_items: - await self._underlying_add_items(list(previous_items), wrapper=wrapper) + await self.underlying_session.add_items(list(previous_items)) except Exception: logger.warning( "Failed to restore session history after compaction replacement failed.", @@ -355,18 +314,10 @@ async def _restore_underlying_session_items( replacement_error, ) - async def _defer_compaction( - self, - response_id: str, - store: bool | None = None, - *, - wrapper: RunContextWrapper[Any] | None = None, - ) -> None: + async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: if self._deferred_response_id is not None: return - compaction_candidate_items, session_items = await self._ensure_compaction_candidates( - wrapper=wrapper - ) + compaction_candidate_items, session_items = await self._ensure_compaction_candidates() resolved_mode = self._resolve_compaction_mode_for_response( response_id=response_id, store=store, @@ -395,7 +346,9 @@ async def add_items( *, wrapper: RunContextWrapper[Any] | None = None, ) -> None: - await self._underlying_add_items(items, wrapper=wrapper) + # See get_items: this decorator does not forward the run context to the underlying + # session, so history rewrites during compaction stay internally consistent. + await self.underlying_session.add_items(items) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) @@ -419,16 +372,12 @@ async def clear_session(self) -> None: async def _ensure_compaction_candidates( self, - *, - wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[list[TResponseInputItem], list[TResponseInputItem]]: """Lazy-load and cache compaction candidates.""" if self._compaction_candidate_items is not None and self._session_items is not None: return (self._compaction_candidate_items[:], self._session_items[:]) - history = _normalize_compaction_session_items( - await self._underlying_get_items(wrapper=wrapper) - ) + history = _normalize_compaction_session_items(await self.underlying_session.get_items()) candidates = select_compaction_candidate_items(history) self._compaction_candidate_items = candidates self._session_items = history diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 888a7e6ae7..68ec558b93 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -162,19 +162,8 @@ class OpenAIResponsesCompactionArgs(TypedDict, total=False): class OpenAIResponsesCompactionAwareSession(Session, Protocol): """Protocol for session implementations that support responses compaction.""" - async def run_compaction( - self, - args: OpenAIResponsesCompactionArgs | None = None, - *, - wrapper: RunContextWrapper[Any] | None = None, - ) -> None: - """Run the compaction process for the session. - - Args: - args: Optional compaction arguments. - wrapper: Optional run context wrapper for the current run. Implementations - may ignore this parameter. - """ + async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: + """Run the compaction process for the session.""" ... diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index 7db002b1a7..05c5b7fe63 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -398,10 +398,7 @@ async def save_result_to_session( if has_local_tool_outputs: defer_compaction = getattr(session, "_defer_compaction", None) if callable(defer_compaction): - if session_method_accepts_wrapper(defer_compaction): - result = defer_compaction(response_id, store=store, wrapper=wrapper) - else: - result = defer_compaction(response_id, store=store) + result = defer_compaction(response_id, store=store) if inspect.isawaitable(result): await result logger.debug( @@ -427,10 +424,7 @@ async def save_result_to_session( } if store is not None: compaction_args["store"] = store - if session_method_accepts_wrapper(session.run_compaction): - await session.run_compaction(compaction_args, wrapper=wrapper) - else: - await session.run_compaction(compaction_args) + await session.run_compaction(compaction_args) return saved_run_items_count diff --git a/tests/memory/test_session_context_wrapper.py b/tests/memory/test_session_context_wrapper.py index a6c2e15f5a..e7644d8dfd 100644 --- a/tests/memory/test_session_context_wrapper.py +++ b/tests/memory/test_session_context_wrapper.py @@ -414,8 +414,13 @@ async def test_encrypted_session_does_not_break_legacy_underlying_session(): @pytest.mark.asyncio -async def test_compaction_session_forwards_wrapper_to_underlying_session(): - """OpenAIResponsesCompactionSession forwards the wrapper to opted-in underlying sessions.""" +async def test_compaction_session_accepts_but_does_not_forward_wrapper(): + """OpenAIResponsesCompactionSession accepts the wrapper but does not forward it. + + The decorator rewrites history via clear_session + add_items during compaction, which + cannot be scoped consistently through the get_items/add_items wrapper contract, so it + deliberately operates on the underlying session's default scope. + """ from agents.memory.openai_responses_compaction_session import ( OpenAIResponsesCompactionSession, ) @@ -425,39 +430,14 @@ async def test_compaction_session_forwards_wrapper_to_underlying_session(): wrapper = RunContextWrapper(context=UserInfo()) items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + # Accepting the keyword-only wrapper keeps it protocol-compatible and still works. await session.add_items(items, wrapper=wrapper) retrieved = await session.get_items(wrapper=wrapper) - assert underlying.add_items_wrappers == [wrapper] - assert underlying.get_items_wrappers == [wrapper] assert len(retrieved) == 1 - - -@pytest.mark.asyncio -async def test_compaction_run_compaction_forwards_wrapper_to_underlying_reads(): - """run_compaction forwards the wrapper to the underlying session's history reads.""" - from agents.memory.openai_responses_compaction_session import ( - OpenAIResponsesCompactionSession, - ) - - underlying = ContextAwareSession() - await underlying.add_items([{"role": "user", "content": "hello"}]) - underlying.get_items_wrappers.clear() - - # Decline actual compaction so no OpenAI client call is made; only the candidate - # read path (which forwards the wrapper) runs. - session = OpenAIResponsesCompactionSession( - "compaction-run", - underlying, - should_trigger_compaction=lambda _info: False, - ) - wrapper = RunContextWrapper(context=UserInfo()) - - await session.run_compaction({"response_id": "resp_1"}, wrapper=wrapper) - - assert len(underlying.get_items_wrappers) > 0 - for received in underlying.get_items_wrappers: - assert received is wrapper + # The wrapper is intentionally not propagated to the underlying session. + assert underlying.add_items_wrappers == [None] + assert underlying.get_items_wrappers == [None] @pytest.mark.asyncio From 6a4dca4e49b207cfe743dac5042c788ce4ea63a9 Mon Sep 17 00:00:00 2001 From: jawwad-ali Date: Wed, 10 Jun 2026 20:09:34 +0500 Subject: [PATCH 4/4] fix: do not forward run context wrapper through session retry rewind The retry-rewind path removes items via pop_item, which is outside the get_items/add_items run-context wrapper contract. Scoping only the tail reads (but not the pop) would let a wrapper-scoped session verify the scoped tail and then pop from the default scope, leaving the scoped items in place. Keep the rewind path unscoped so its reads and pops use the same (default) scope, mirroring the OpenAIResponsesCompactionSession decision. Documented that retry rewind and compaction run against the default scope for wrapper-scoped sessions. --- docs/sessions/index.md | 2 +- src/agents/run_internal/run_loop.py | 8 ++--- .../run_internal/session_persistence.py | 36 ++++++++----------- tests/memory/test_session_context_wrapper.py | 14 ++++---- 4 files changed, 24 insertions(+), 36 deletions(-) diff --git a/docs/sessions/index.md b/docs/sessions/index.md index 3f572d17a6..f0cf9c5371 100644 --- a/docs/sessions/index.md +++ b/docs/sessions/index.md @@ -726,7 +726,7 @@ class ContextAwareSession(SessionABC): The `wrapper` parameter may be `None`, for example when session methods are called directly rather than through the runner, so implementations should always handle that case. Sessions that accept `**kwargs` on these methods also receive the wrapper through them. -Wrapping a context-aware session in `OpenAIResponsesCompactionSession` is not supported for run-context scoping: that decorator rewrites history by clearing and replacing the underlying store during compaction, which cannot be scoped consistently through the `get_items`/`add_items` wrapper, so it does not forward the run context to the underlying session. Transparent wrappers such as `EncryptedSession` do forward the wrapper to underlying sessions that opt in. +The wrapper is forwarded only on the standard history read/write paths. Operations that rely on `pop_item` or `clear_session` — the conversation-retry rewind and the `OpenAIResponsesCompactionSession` decorator (which clears and replaces history during compaction) — are outside the `get_items`/`add_items` wrapper contract and would be inconsistent if only partially scoped, so they operate on the session's default scope and do not forward the wrapper. If you scope storage by `wrapper.context`, treat retry rewind and compaction as running against that default scope. Transparent wrappers such as `EncryptedSession` do forward the wrapper to underlying sessions that opt in. ## Community session implementations diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 77626b144b..45fe354be1 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1463,9 +1463,7 @@ def _tool_search_fingerprint(raw_item: Any) -> str: async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items( - session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper - ) + await rewind_session_items(session, items_to_rewind, server_conversation_tracker) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) @@ -1889,9 +1887,7 @@ async def get_new_response( async def rewind_model_request() -> None: items_to_rewind = session_items_to_rewind if session_items_to_rewind is not None else [] - await rewind_session_items( - session, items_to_rewind, server_conversation_tracker, wrapper=context_wrapper - ) + await rewind_session_items(session, items_to_rewind, server_conversation_tracker) if server_conversation_tracker is not None: server_conversation_tracker.rewind_input(filtered.input) diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index 05c5b7fe63..3487939543 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -459,12 +459,14 @@ async def rewind_session_items( session: Session | None, items: Sequence[TResponseInputItem], server_tracker: OpenAIServerConversationTracker | None = None, - *, - wrapper: RunContextWrapper[Any] | None = None, ) -> None: """ Best-effort helper to roll back items recently persisted to a session when a conversation retry is needed, so we do not accumulate duplicate inputs on lock errors. + + This path removes items via ``pop_item``, which is outside the ``get_items``/``add_items`` + run-context wrapper contract, so it does not forward the wrapper: a wrapper-scoped session's + retry rewind operates on the session's default scope. """ if session is None or not items: return @@ -501,7 +503,6 @@ async def rewind_session_items( "Skipping session rewind because the current tail does not match the retry-owned suffix" ), pop_failure_warning="Failed to rewind session item: %s", - wrapper=wrapper, ) if not rewound: return @@ -510,14 +511,13 @@ async def rewind_session_items( session, snapshot_serializations, ignore_ids_for_matching=ignore_ids_for_matching, - wrapper=wrapper, ) if session is None or server_tracker is None: return try: - latest_items = await _session_get_items(session, limit=1, wrapper=wrapper) + latest_items = await session.get_items(limit=1) except Exception as exc: logger.debug("Failed to peek session items while rewinding: %s", exc) return @@ -530,7 +530,7 @@ async def rewind_session_items( return try: - session_items = await _session_get_items(session, wrapper=wrapper) + session_items = await session.get_items() except Exception as exc: logger.debug("Failed to inspect session tail while stripping stray items: %s", exc) return @@ -558,7 +558,6 @@ async def rewind_session_items( "retry-owned conversation items" ), pop_failure_warning="Failed to strip stray session item: %s", - wrapper=wrapper, ) @@ -568,7 +567,6 @@ async def wait_for_session_cleanup( *, max_attempts: int = 5, ignore_ids_for_matching: bool = False, - wrapper: RunContextWrapper[Any] | None = None, ) -> None: """ Confirm that rewound items are no longer present in the session tail so the store stays @@ -581,7 +579,7 @@ async def wait_for_session_cleanup( for attempt in range(max_attempts): try: - tail_items = await _session_get_items(session, limit=window, wrapper=wrapper) + tail_items = await session.get_items(limit=window) except Exception as exc: logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc) await asyncio.sleep(0.1 * (attempt + 1)) @@ -706,16 +704,13 @@ async def _rewind_session_tail_suffix( ignore_ids_for_matching: bool, mismatch_warning: str, pop_failure_warning: str, - wrapper: RunContextWrapper[Any] | None = None, ) -> bool: """Remove an exact serialized suffix from the session tail, aborting when the tail diverges.""" if not expected_serializations: return True try: - tail_items = await _session_get_items( - session, limit=len(expected_serializations), wrapper=wrapper - ) + tail_items = await session.get_items(limit=len(expected_serializations)) except Exception as exc: logger.warning(pop_failure_warning, exc) return False @@ -743,12 +738,12 @@ async def _rewind_session_tail_suffix( if inspect.isawaitable(result): result = await result except Exception as exc: - await _restore_popped_session_items(session, popped_items, wrapper=wrapper) + await _restore_popped_session_items(session, popped_items) logger.warning(pop_failure_warning, exc) return False if result is None: - await _restore_popped_session_items(session, popped_items, wrapper=wrapper) + await _restore_popped_session_items(session, popped_items) logger.warning(mismatch_warning) return False @@ -757,7 +752,7 @@ async def _rewind_session_tail_suffix( result, ignore_ids_for_matching=ignore_ids_for_matching ) if popped_serialized != expected: - await _restore_popped_session_items(session, popped_items, wrapper=wrapper) + await _restore_popped_session_items(session, popped_items) logger.warning(mismatch_warning) return False @@ -765,10 +760,7 @@ async def _rewind_session_tail_suffix( async def _restore_popped_session_items( - session: Session, - popped_items: Sequence[TResponseInputItem], - *, - wrapper: RunContextWrapper[Any] | None = None, + session: Session, popped_items: Sequence[TResponseInputItem] ) -> None: """Best-effort restoration for items popped during a failed rewind attempt.""" if not popped_items: @@ -779,7 +771,9 @@ async def _restore_popped_session_items( return try: - await _session_add_items(session, list(reversed(popped_items)), wrapper=wrapper) + result = add_items(list(reversed(popped_items))) + if inspect.isawaitable(result): + await result except Exception as exc: logger.warning("Failed to restore session items after a rewind mismatch: %s", exc) diff --git a/tests/memory/test_session_context_wrapper.py b/tests/memory/test_session_context_wrapper.py index e7644d8dfd..06d471bb53 100644 --- a/tests/memory/test_session_context_wrapper.py +++ b/tests/memory/test_session_context_wrapper.py @@ -441,8 +441,9 @@ async def test_compaction_session_accepts_but_does_not_forward_wrapper(): @pytest.mark.asyncio -async def test_rewind_session_items_forwards_wrapper(): - """The retry-rewind helper forwards the wrapper to the session it reads and pops.""" +async def test_rewind_session_items_does_not_forward_wrapper(): + """The retry-rewind helper removes items via pop_item, which is outside the wrapper + contract, so it operates on the session's default scope and does not forward the wrapper.""" from agents.run_internal.session_persistence import rewind_session_items session = ContextAwareSession() @@ -450,11 +451,8 @@ async def test_rewind_session_items_forwards_wrapper(): await session.add_items([item]) session.get_items_wrappers.clear() - wrapper = RunContextWrapper(context=UserInfo()) - await rewind_session_items(session, [item], wrapper=wrapper) + await rewind_session_items(session, [item]) - # The rewind read the tail with the wrapper and popped the matching item back off. - assert len(session.get_items_wrappers) > 0 - for received in session.get_items_wrappers: - assert received is wrapper + # The rewind still works (the matching item is popped) but no wrapper is forwarded. assert await session.get_items() == [] + assert all(received is None for received in session.get_items_wrappers)