Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ...items import TResponseInputItem
from ...memory import SessionABC
from ...memory.session_settings import SessionSettings
from ...memory.session_settings import SessionSettings, resolve_session_limit


class AsyncSQLiteSession(SessionABC):
Expand All @@ -30,6 +30,7 @@ def __init__(
db_path: str | Path = ":memory:",
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
session_settings: SessionSettings | None = None,
):
"""Initialize the async SQLite session.

Expand All @@ -39,8 +40,11 @@ def __init__(
sessions_table: Name of the table to store session metadata. Defaults to
'agent_sessions'
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
session_settings: Session configuration settings including default limit for
retrieving items. If None, uses default SessionSettings().
"""
self.session_id = session_id
self.session_settings = session_settings or SessionSettings()
self.db_path = db_path
self.sessions_table = sessions_table
self.messages_table = messages_table
Expand Down Expand Up @@ -106,15 +110,17 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
limit: Maximum number of items to retrieve. If None, retrieves all items.
limit: Maximum number of items to retrieve. If None, uses session_settings.limit.
When specified, returns the latest N items in chronological order.

Returns:
List of input items representing the conversation history
"""

session_limit = resolve_session_limit(limit, self.session_settings)

async with self._locked_connection() as conn:
if limit is None:
if session_limit is None:
cursor = await conn.execute(
f"""
SELECT message_data FROM {self.messages_table}
Expand All @@ -131,13 +137,13 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
ORDER BY id DESC
LIMIT ?
""",
(self.session_id, limit),
(self.session_id, session_limit),
)

rows = list(await cursor.fetchall())
await cursor.close()

if limit is not None:
if session_limit is not None:
rows = rows[::-1]

items: list[TResponseInputItem] = []
Expand Down
69 changes: 69 additions & 0 deletions tests/extensions/memory/test_async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from agents import Agent, Runner, TResponseInputItem
from agents.extensions.memory import AsyncSQLiteSession
from agents.memory import SessionSettings
from tests.fake_model import FakeModel
from tests.test_responses import get_text_message

Expand Down Expand Up @@ -140,6 +141,74 @@ async def test_async_sqlite_session_get_items_limit():
await session.close()


async def test_async_sqlite_session_session_settings_default():
"""Test that session_settings defaults to empty SessionSettings."""
session = AsyncSQLiteSession("async_default_settings")

assert isinstance(session.session_settings, SessionSettings)
assert session.session_settings.limit is None

await session.close()


async def test_async_sqlite_session_session_settings_constructor():
"""Test passing session_settings via constructor."""
session = AsyncSQLiteSession(
"async_constructor_settings",
session_settings=SessionSettings(limit=5),
)

assert session.session_settings is not None
assert session.session_settings.limit == 5

await session.close()


async def test_async_sqlite_session_get_items_uses_session_settings_limit():
"""Test that get_items uses session_settings.limit as default."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "async_settings_limit.db"
session = AsyncSQLiteSession(
"async_settings_limit",
db_path,
session_settings=SessionSettings(limit=3),
)

items: list[TResponseInputItem] = [
{"role": "user", "content": f"Message {i}"} for i in range(5)
]
await session.add_items(items)

retrieved = await session.get_items()
assert retrieved == items[-3:]

await session.close()


async def test_async_sqlite_session_explicit_limit_overrides_session_settings():
"""Test that explicit limit parameter overrides session_settings."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "async_settings_override.db"
session = AsyncSQLiteSession(
"async_settings_override",
db_path,
session_settings=SessionSettings(limit=5),
)

items: list[TResponseInputItem] = [
{"role": "user", "content": f"Message {i}"} for i in range(10)
]
await session.add_items(items)

retrieved = await session.get_items(limit=2)
assert retrieved == items[-2:]

no_items = await session.get_items(limit=0)
assert no_items == []

await session.close()


async def test_async_sqlite_session_unicode_content():
"""Test AsyncSQLiteSession stores unicode content."""
session = AsyncSQLiteSession("async_unicode")
Expand Down
Loading