Skip to content

Commit e6472db

Browse files
committed
fix: preserve legacy get_items semantics
1 parent d493ce3 commit e6472db

5 files changed

Lines changed: 92 additions & 7 deletions

File tree

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ async def add_items(
140140
async def get_items(
141141
self,
142142
limit: int | None = None,
143-
wrapper: RunContextWrapper[Any] | None = None,
144143
branch_id: str | None = None,
144+
wrapper: RunContextWrapper[Any] | None = None,
145145
) -> list[TResponseInputItem]:
146146
"""Get items from current or specified branch.
147147

src/agents/extensions/memory/encrypt_session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,16 @@ async def get_items(
206206
limit: int | None = None,
207207
wrapper: RunContextWrapper[Any] | None = None,
208208
) -> list[TResponseInputItem]:
209-
if wrapper is not None and _method_accepts_wrapper(self.underlying_session.get_items):
209+
accepts_wrapper = wrapper is not None and _method_accepts_wrapper(
210+
self.underlying_session.get_items
211+
)
212+
if limit is None:
213+
if accepts_wrapper:
214+
encrypted_items = await self.underlying_session.get_items(wrapper=wrapper)
215+
else:
216+
encrypted_items = await self.underlying_session.get_items()
217+
elif accepts_wrapper:
210218
encrypted_items = await self.underlying_session.get_items(limit, wrapper=wrapper)
211-
elif limit is None and not _method_accepts_limit(self.underlying_session.get_items):
212-
encrypted_items = await self.underlying_session.get_items()
213219
else:
214220
encrypted_items = await self.underlying_session.get_items(limit)
215221
valid_items: list[TResponseInputItem] = []

src/agents/run_internal/session_persistence.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,15 @@ async def _session_get_items(
8282
*,
8383
wrapper: RunContextWrapper[Any] | None = None,
8484
) -> list[TResponseInputItem]:
85-
if wrapper is not None and _session_method_accepts_wrapper(session.get_items):
86-
return await session.get_items(limit=limit, wrapper=wrapper)
87-
if limit is None and not _session_method_accepts_limit(session.get_items):
85+
accepts_wrapper = wrapper is not None and _session_method_accepts_wrapper(session.get_items)
86+
87+
if limit is None:
88+
if accepts_wrapper:
89+
return await session.get_items(wrapper=wrapper)
8890
return await session.get_items()
91+
92+
if accepts_wrapper:
93+
return await session.get_items(limit=limit, wrapper=wrapper)
8994
return await session.get_items(limit=limit)
9095

9196

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
import tempfile
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
from agents.extensions.memory.advanced_sqlite_session import AdvancedSQLiteSession
9+
10+
pytestmark = pytest.mark.asyncio
11+
12+
13+
async def test_advanced_sqlite_get_items_preserves_branch_id_positional_argument() -> None:
14+
with tempfile.TemporaryDirectory() as temp_dir:
15+
db_path = Path(temp_dir) / "advanced.db"
16+
session = AdvancedSQLiteSession(session_id="test", db_path=db_path, create_tables=True)
17+
18+
await session.add_items([
19+
{"role": "user", "content": "main message"},
20+
])
21+
branch_id = await session.create_branch_from_turn(1, "branch-a")
22+
assert branch_id == "branch-a"
23+
await session.add_items([
24+
{"role": "user", "content": "branch message"},
25+
])
26+
await session.switch_to_branch("main")
27+
28+
branch_items = await session.get_items(50, "branch-a")
29+
contents = [item.get("content") for item in branch_items if isinstance(item, dict)]
30+
31+
assert "branch message" in contents
32+
assert "main message" not in contents

tests/test_session.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
1111
from agents.memory.session import SessionABC
1212
from agents.run_context import RunContextWrapper
13+
from agents.run_internal.session_persistence import prepare_input_with_session
1314

1415
from .fake_model import FakeModel
1516
from .test_responses import get_text_message
@@ -133,6 +134,27 @@ async def clear_session(self) -> None:
133134
self.items.clear()
134135

135136

137+
class DefaultLimitedSession:
138+
session_id = "default-limited"
139+
140+
def __init__(self) -> None:
141+
self.items: list[TResponseInputItem] = []
142+
self.get_call_count = 0
143+
144+
async def get_items(self, limit: int = 1) -> list[TResponseInputItem]:
145+
self.get_call_count += 1
146+
return list(self.items[-limit:]) if limit > 0 else []
147+
148+
async def add_items(self, items: list[TResponseInputItem]) -> None:
149+
self.items.extend(items)
150+
151+
async def pop_item(self) -> TResponseInputItem | None:
152+
return self.items.pop() if self.items else None
153+
154+
async def clear_session(self) -> None:
155+
self.items.clear()
156+
157+
136158
# Parametrized tests for different runner methods
137159
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
138160
@pytest.mark.asyncio
@@ -296,6 +318,26 @@ async def test_legacy_session_without_limit_keyword_remains_compatible(runner_me
296318
assert session.items
297319

298320

321+
@pytest.mark.asyncio
322+
async def test_get_items_preserves_default_limit_when_none_is_unset() -> None:
323+
session = DefaultLimitedSession()
324+
session.items = [
325+
{"role": "user", "content": "one"},
326+
{"role": "assistant", "content": "two"},
327+
]
328+
329+
prepared_input, _ = await prepare_input_with_session(
330+
"new",
331+
session,
332+
session_input_callback=None,
333+
)
334+
335+
assert isinstance(prepared_input, list)
336+
assert prepared_input[0]["content"] == "two"
337+
assert prepared_input[1]["content"] == "new"
338+
assert session.get_call_count == 1
339+
340+
299341
@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
300342
@pytest.mark.asyncio
301343
async def test_session_memory_different_sessions_parametrized(runner_method):

0 commit comments

Comments
 (0)