Skip to content

Commit 6f96adb

Browse files
committed
fix: unify session get_items compatibility fallback
1 parent 7ca8a3b commit 6f96adb

6 files changed

Lines changed: 216 additions & 30 deletions

File tree

src/agents/extensions/memory/encrypt_session.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,33 @@ def _method_accepts_limit(method: Any) -> bool:
127127
)
128128

129129

130+
async def _delegate_get_items(
131+
session: SessionABC,
132+
limit: int | None = None,
133+
*,
134+
wrapper: RunContextWrapper[Any] | None = None,
135+
) -> list[TResponseInputItem]:
136+
accepts_wrapper = wrapper is not None and _method_accepts_wrapper(session.get_items)
137+
accepts_limit = _method_accepts_limit(session.get_items)
138+
139+
if limit is None:
140+
if accepts_wrapper:
141+
return await session.get_items(wrapper=wrapper)
142+
return await session.get_items()
143+
144+
if accepts_limit:
145+
if accepts_wrapper:
146+
return await session.get_items(limit=limit, wrapper=wrapper)
147+
return await session.get_items(limit=limit)
148+
149+
if accepts_wrapper:
150+
items = await session.get_items(wrapper=wrapper)
151+
else:
152+
items = await session.get_items()
153+
154+
return items[-limit:] if limit > 0 else []
155+
156+
130157
class EncryptedSession(SessionABC):
131158
"""Encrypted wrapper for Session implementations with TTL-based expiration.
132159
@@ -206,18 +233,11 @@ async def get_items(
206233
limit: int | None = None,
207234
wrapper: RunContextWrapper[Any] | None = None,
208235
) -> list[TResponseInputItem]:
209-
accepts_wrapper = wrapper is not None and _method_accepts_wrapper(
210-
self.underlying_session.get_items
236+
encrypted_items = await _delegate_get_items(
237+
self.underlying_session,
238+
limit=limit,
239+
wrapper=wrapper,
211240
)
212-
if limit is None:
213-
if accepts_wrapper:
214-
encrypted_items = await self.underlying_session.get_items(wrapper=wrapper)
215-
else:
216-
encrypted_items = await self.underlying_session.get_items()
217-
elif accepts_wrapper:
218-
encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper)
219-
else:
220-
encrypted_items = await self.underlying_session.get_items(limit)
221241
valid_items: list[TResponseInputItem] = []
222242
for enc in encrypted_items:
223243
item = self._unwrap(enc)

src/agents/memory/openai_responses_compaction_session.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,33 @@ def _method_accepts_limit(method: Any) -> bool:
5656
)
5757

5858

59+
async def _delegate_get_items(
60+
session: Session,
61+
limit: int | None = None,
62+
*,
63+
wrapper: RunContextWrapper[Any] | None = None,
64+
) -> list[TResponseInputItem]:
65+
accepts_wrapper = wrapper is not None and _method_accepts_wrapper(session.get_items)
66+
accepts_limit = _method_accepts_limit(session.get_items)
67+
68+
if limit is None:
69+
if accepts_wrapper:
70+
return await session.get_items(wrapper=wrapper)
71+
return await session.get_items()
72+
73+
if accepts_limit:
74+
if accepts_wrapper:
75+
return await session.get_items(limit=limit, wrapper=wrapper)
76+
return await session.get_items(limit=limit)
77+
78+
if accepts_wrapper:
79+
items = await session.get_items(wrapper=wrapper)
80+
else:
81+
items = await session.get_items()
82+
83+
return items[-limit:] if limit > 0 else []
84+
85+
5986
def select_compaction_candidate_items(
6087
items: list[TResponseInputItem],
6188
) -> list[TResponseInputItem]:
@@ -273,19 +300,12 @@ async def get_items(
273300
limit: int | None = None,
274301
wrapper: RunContextWrapper[Any] | None = None,
275302
) -> list[TResponseInputItem]:
276-
accepts_wrapper = wrapper is not None and _method_accepts_wrapper(
277-
self.underlying_session.get_items
303+
return await _delegate_get_items(
304+
self.underlying_session,
305+
limit=limit,
306+
wrapper=wrapper,
278307
)
279308

280-
if limit is None:
281-
if accepts_wrapper:
282-
return await self.underlying_session.get_items(wrapper=wrapper)
283-
return await self.underlying_session.get_items()
284-
285-
if accepts_wrapper and _method_accepts_limit(self.underlying_session.get_items):
286-
return await self.underlying_session.get_items(limit=limit, wrapper=wrapper)
287-
return await self.underlying_session.get_items(limit)
288-
289309
async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None:
290310
if self._deferred_response_id is not None:
291311
return

src/agents/run_internal/session_persistence.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,24 @@ async def _session_get_items(
8383
wrapper: RunContextWrapper[Any] | None = None,
8484
) -> list[TResponseInputItem]:
8585
accepts_wrapper = wrapper is not None and _session_method_accepts_wrapper(session.get_items)
86+
accepts_limit = _session_method_accepts_limit(session.get_items)
8687

8788
if limit is None:
8889
if accepts_wrapper:
8990
return await session.get_items(wrapper=wrapper)
9091
return await session.get_items()
9192

93+
if accepts_limit:
94+
if accepts_wrapper:
95+
return await session.get_items(limit=limit, wrapper=wrapper)
96+
return await session.get_items(limit=limit)
97+
9298
if accepts_wrapper:
93-
return await session.get_items(limit=limit, wrapper=wrapper)
94-
return await session.get_items(limit=limit)
99+
items = await session.get_items(wrapper=wrapper)
100+
else:
101+
items = await session.get_items()
102+
103+
return items[-limit:] if limit > 0 else []
95104

96105

97106
async def _session_add_items(
@@ -539,7 +548,7 @@ async def rewind_session_items(
539548
return
540549

541550
try:
542-
latest_items = await session.get_items(limit=1)
551+
latest_items = await _session_get_items(session, limit=1)
543552
except Exception as exc:
544553
logger.debug("Failed to peek session items while rewinding: %s", exc)
545554
return
@@ -587,7 +596,7 @@ async def wait_for_session_cleanup(
587596

588597
for attempt in range(max_attempts):
589598
try:
590-
tail_items = await session.get_items(limit=window)
599+
tail_items = await _session_get_items(session, limit=window)
591600
except Exception as exc:
592601
logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc)
593602
await asyncio.sleep(0.1 * (attempt + 1))

tests/extensions/memory/test_encrypt_session.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
1515
from agents.extensions.memory.encrypt_session import EncryptedSession
1616
from agents.memory.session import SessionABC
17+
from agents.run_context import RunContextWrapper
1718
from tests.fake_model import FakeModel
1819
from tests.test_responses import get_text_message
1920

@@ -114,6 +115,46 @@ async def test_encrypted_session_with_runner(
114115
underlying_session.close()
115116

116117

118+
async def test_encrypted_session_preserves_wrapper_only_underlying_with_limit(
119+
encryption_key: str,
120+
):
121+
class WrapperOnlyUnderlyingSession:
122+
def __init__(self) -> None:
123+
self.session_id = "test_session"
124+
self.items: list[TResponseInputItem] = []
125+
self.get_wrappers: list[object | None] = []
126+
127+
async def get_items(self, wrapper: object = None) -> list[TResponseInputItem]:
128+
self.get_wrappers.append(wrapper)
129+
return list(self.items)
130+
131+
async def add_items(self, items: list[TResponseInputItem]) -> None:
132+
self.items.extend(items)
133+
134+
async def pop_item(self) -> TResponseInputItem | None:
135+
return None
136+
137+
async def clear_session(self) -> None:
138+
self.items.clear()
139+
140+
underlying = WrapperOnlyUnderlyingSession()
141+
underlying.items = [
142+
{"role": "user", "content": "one"},
143+
{"role": "assistant", "content": "two"},
144+
]
145+
session = EncryptedSession(
146+
session_id="test_session",
147+
underlying_session=cast(SessionABC, underlying),
148+
encryption_key=encryption_key,
149+
)
150+
151+
wrapper = RunContextWrapper(context={"request_id": "encrypt"})
152+
items = await session.get_items(limit=1, wrapper=wrapper)
153+
154+
assert items[-1].get("content") == "two"
155+
assert underlying.get_wrappers == [wrapper]
156+
157+
117158
async def test_encrypted_session_preserves_legacy_underlying_signatures(
118159
agent: Agent,
119160
encryption_key: str,

tests/memory/test_openai_responses_compaction_session.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_openai_model_name,
2020
select_compaction_candidate_items,
2121
)
22+
from agents.run_context import RunContextWrapper
2223
from tests.fake_model import FakeModel
2324
from tests.test_responses import get_function_tool, get_function_tool_call, get_text_message
2425
from tests.utils.simple_session import SimpleListSession
@@ -131,18 +132,59 @@ async def pop_item(self) -> TResponseInputItem | None:
131132
async def clear_session(self) -> None:
132133
return None
133134

134-
underlying = cast(Session, WrapperOnlySession())
135+
underlying = WrapperOnlySession()
135136
session = OpenAIResponsesCompactionSession(
136137
session_id="test",
137-
underlying_session=underlying,
138+
underlying_session=cast(Session, underlying),
138139
)
139140

140-
wrapper = SimpleNamespace(context={"request_id": "abc"})
141+
wrapper = RunContextWrapper(context={"request_id": "abc"})
141142
items = await session.get_items(wrapper=wrapper)
142143

143144
assert items == []
144145
assert underlying.calls == [(None, wrapper)]
145146

147+
@pytest.mark.asyncio
148+
async def test_get_items_with_limit_preserves_wrapper_only_delegate_shape(self) -> None:
149+
class WrapperOnlySession:
150+
session_id = "test-session"
151+
152+
def __init__(self) -> None:
153+
self.calls: list[tuple[int | None, Any]] = []
154+
self.items: list[TResponseInputItem] = [
155+
cast(TResponseInputItem, {"role": "user", "content": "one"}),
156+
cast(TResponseInputItem, {"role": "assistant", "content": "two"}),
157+
]
158+
159+
async def get_items(
160+
self,
161+
wrapper: Any = None,
162+
) -> list[TResponseInputItem]:
163+
self.calls.append((None, wrapper))
164+
return list(self.items)
165+
166+
async def add_items(self, items: list[TResponseInputItem]) -> None:
167+
return None
168+
169+
async def pop_item(self) -> TResponseInputItem | None:
170+
return None
171+
172+
async def clear_session(self) -> None:
173+
return None
174+
175+
underlying = WrapperOnlySession()
176+
session = OpenAIResponsesCompactionSession(
177+
session_id="test",
178+
underlying_session=cast(Session, underlying),
179+
)
180+
181+
wrapper = RunContextWrapper(context={"request_id": "abc"})
182+
items = await session.get_items(limit=1, wrapper=wrapper)
183+
184+
assert len(items) == 1
185+
assert items[0].get("content") == "two"
186+
assert underlying.calls == [(None, wrapper)]
187+
146188
@pytest.mark.asyncio
147189
async def test_add_items_delegates(self) -> None:
148190
mock_session = self.create_mock_session()

tests/test_session.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import tempfile
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, cast
77

88
import pytest
99

@@ -166,6 +166,35 @@ async def clear_session(self) -> None:
166166
self.items.clear()
167167

168168

169+
class WrapperOnlySession:
170+
session_id = "wrapper-only"
171+
session_settings: SessionSettings | None = None
172+
173+
def __init__(self) -> None:
174+
self.items: list[TResponseInputItem] = []
175+
self.get_wrappers: list[RunContextWrapper[Any] | None] = []
176+
177+
async def get_items(
178+
self,
179+
wrapper: RunContextWrapper[Any] | None = None,
180+
) -> list[TResponseInputItem]:
181+
self.get_wrappers.append(wrapper)
182+
return list(self.items)
183+
184+
async def add_items(
185+
self,
186+
items: list[TResponseInputItem],
187+
wrapper: RunContextWrapper[Any] | None = None,
188+
) -> None:
189+
self.items.extend(items)
190+
191+
async def pop_item(self) -> TResponseInputItem | None:
192+
return self.items.pop() if self.items else None
193+
194+
async def clear_session(self) -> None:
195+
self.items.clear()
196+
197+
169198
# Parametrized tests for different runner methods
170199
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
171200
@pytest.mark.asyncio
@@ -351,6 +380,31 @@ async def test_get_items_preserves_default_limit_when_none_is_unset() -> None:
351380
assert session.get_call_count == 1
352381

353382

383+
@pytest.mark.asyncio
384+
async def test_get_items_with_limit_preserves_wrapper_only_delegate_shape() -> None:
385+
session = WrapperOnlySession()
386+
session.items = [
387+
{"role": "user", "content": "one"},
388+
{"role": "assistant", "content": "two"},
389+
]
390+
wrapper = RunContextWrapper(context={"request_id": "wrapper-only"})
391+
392+
prepared_input, _ = await prepare_input_with_session(
393+
"new",
394+
cast(SessionABC, session),
395+
session_input_callback=None,
396+
session_settings=SessionSettings(limit=1),
397+
wrapper=wrapper,
398+
)
399+
400+
assert isinstance(prepared_input, list)
401+
first = prepared_input[0]
402+
second = prepared_input[1]
403+
assert isinstance(first, dict) and first.get("content") == "two"
404+
assert isinstance(second, dict) and second.get("content") == "new"
405+
assert session.get_wrappers == [wrapper]
406+
407+
354408
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
355409
@pytest.mark.asyncio
356410
async def test_session_memory_different_sessions_parametrized(runner_method):

0 commit comments

Comments
 (0)