Skip to content

Commit d493ce3

Browse files
committed
fix: preserve legacy wrapper compatibility in EncryptedSession
1 parent a5c1df8 commit d493ce3

2 files changed

Lines changed: 75 additions & 2 deletions

File tree

src/agents/extensions/memory/encrypt_session.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from __future__ import annotations
2929

3030
import base64
31+
import inspect
3132
import json
3233
from typing import Any, cast
3334

@@ -97,6 +98,35 @@ def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
9798
)
9899

99100

101+
def _method_accepts_wrapper(method: Any) -> bool:
102+
try:
103+
parameters = tuple(inspect.signature(method).parameters.values())
104+
except (TypeError, ValueError):
105+
return False
106+
107+
return any(
108+
parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper"
109+
for parameter in parameters
110+
)
111+
112+
113+
def _method_accepts_limit(method: Any) -> bool:
114+
try:
115+
parameters = tuple(inspect.signature(method).parameters.values())
116+
except (TypeError, ValueError):
117+
return False
118+
119+
return any(
120+
(
121+
parameter.kind
122+
in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
123+
and parameter.name == "limit"
124+
)
125+
or parameter.kind is inspect.Parameter.VAR_KEYWORD
126+
for parameter in parameters
127+
)
128+
129+
100130
class EncryptedSession(SessionABC):
101131
"""Encrypted wrapper for Session implementations with TTL-based expiration.
102132
@@ -176,8 +206,10 @@ async def get_items(
176206
limit: int | None = None,
177207
wrapper: RunContextWrapper[Any] | None = None,
178208
) -> list[TResponseInputItem]:
179-
if wrapper is not None:
209+
if wrapper is not None and _method_accepts_wrapper(self.underlying_session.get_items):
180210
encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper)
211+
elif limit is None and not _method_accepts_limit(self.underlying_session.get_items):
212+
encrypted_items = await self.underlying_session.get_items()
181213
else:
182214
encrypted_items = await self.underlying_session.get_items(limit)
183215
valid_items: list[TResponseInputItem] = []
@@ -194,7 +226,7 @@ async def add_items(
194226
) -> None:
195227
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
196228
wrapped_items = cast(list[TResponseInputItem], wrapped)
197-
if wrapper is not None:
229+
if wrapper is not None and _method_accepts_wrapper(self.underlying_session.add_items):
198230
await self.underlying_session.add_items(
199231
wrapped_items,
200232
wrapper=wrapper,

tests/extensions/memory/test_encrypt_session.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
1313
from agents.extensions.memory.encrypt_session import EncryptedSession
14+
from agents.memory.session import SessionABC
1415
from tests.fake_model import FakeModel
1516
from tests.test_responses import get_text_message
1617

@@ -111,6 +112,46 @@ async def test_encrypted_session_with_runner(
111112
underlying_session.close()
112113

113114

115+
async def test_encrypted_session_preserves_legacy_underlying_signatures(
116+
agent: Agent,
117+
encryption_key: str,
118+
):
119+
class LegacyUnderlyingSession(SessionABC):
120+
def __init__(self) -> None:
121+
self.session_id = "test_session"
122+
self.items: list[TResponseInputItem] = []
123+
124+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
125+
if limit is None:
126+
return list(self.items)
127+
return list(self.items[-limit:]) if limit > 0 else []
128+
129+
async def add_items(self, items: list[TResponseInputItem]) -> None:
130+
self.items.extend(items)
131+
132+
async def pop_item(self) -> TResponseInputItem | None:
133+
if not self.items:
134+
return None
135+
return self.items.pop()
136+
137+
async def clear_session(self) -> None:
138+
self.items.clear()
139+
140+
legacy_underlying = LegacyUnderlyingSession()
141+
session = EncryptedSession(
142+
session_id="test_session",
143+
underlying_session=legacy_underlying,
144+
encryption_key=encryption_key,
145+
)
146+
147+
assert isinstance(agent.model, FakeModel)
148+
agent.model.set_next_output([get_text_message("Hello")])
149+
result = await Runner.run(agent, "Hi there", session=session)
150+
151+
assert result.final_output == "Hello"
152+
assert legacy_underlying.items
153+
154+
114155
async def test_encrypted_session_pop_item(encryption_key: str, underlying_session: SQLiteSession):
115156
"""Test pop_item functionality."""
116157
session = EncryptedSession(

0 commit comments

Comments
 (0)