diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..770ca771a1 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -2,6 +2,7 @@ import asyncio import json +import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path @@ -30,6 +31,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, ): """Initialize the async SQLite session. @@ -39,27 +42,52 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. """ self.session_id = session_id + self.user_id = user_id self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._connection: aiosqlite.Connection | None = None self._lock = asyncio.Lock() self._init_lock = asyncio.Lock() async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None: """Initialize the database schema for a specific connection.""" + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + await conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -160,11 +188,21 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: return async with self._locked_connection() as conn: + # Ensure user exists if user_id is provided + if self.user_id is not None: + await conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + await conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) message_data = [(self.session_id, json.dumps(item)) for item in items] @@ -233,6 +271,164 @@ async def clear_session(self) -> None: ) await conn.commit() + @classmethod + async def create_session( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + ) -> AsyncSQLiteSession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + + Returns: + A new AsyncSQLiteSession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + + async with session._locked_connection() as conn: + await conn.execute( + f"INSERT OR IGNORE INTO {users_table} (user_id) VALUES (?)", + (user_id,), + ) + await conn.execute( + f"INSERT INTO {sessions_table} (session_id, user_id) VALUES (?, ?)", + (session_id, user_id), + ) + await conn.commit() + + return session + + @classmethod + async def get_sessions( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + limit: int | None = None, + offset: int = 0, + ) -> list[AsyncSQLiteSession]: + """Retrieve all sessions for a user. + + Args: + user_id: The user identifier to look up sessions for. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of AsyncSQLiteSession instances belonging to the user, ordered by + most recently updated first. + """ + probe = cls( + session_id="__probe__", + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + + async with probe._locked_connection() as conn: + if limit is None: + cursor = await conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = await conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + rows = await cursor.fetchall() + await cursor.close() + + await probe.close() + + return [ + cls( + session_id=row[0], + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + for row in rows + ] + + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + async with self._locked_connection() as conn: + if limit is None: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + rows = await cursor.fetchall() + await cursor.close() + return [row[0] for row in rows] + async def close(self) -> None: """Close the database connection.""" if self._connection is None: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 759ddaf5d5..c7eabd5cea 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -26,6 +26,7 @@ import asyncio import json import threading +import uuid from typing import Any, ClassVar from sqlalchemy import ( @@ -86,6 +87,8 @@ def __init__( create_tables: bool = False, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initializes a new SQLAlchemySession. @@ -100,9 +103,13 @@ def __init__( development and testing when migrations aren't used. sessions_table (str, optional): Override the default table name for sessions if needed. messages_table (str, optional): Override the default table name for messages if needed. + users_table (str, optional): Override the default table name for users if needed. + user_id (str | None, optional): Optional user identifier to associate this session + with a user in the agent_users table. session_settings (SessionSettings | None, optional): Session configuration settings """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self._engine = engine self._init_lock = ( @@ -112,10 +119,36 @@ def __init__( ) self._metadata = MetaData() + self._users = Table( + users_table, + self._metadata, + Column("user_id", String, primary_key=True), + Column("metadata", Text, nullable=True), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + Column( + "updated_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + onupdate=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + self._sessions = Table( sessions_table, self._metadata, Column("session_id", String, primary_key=True), + Column( + "user_id", + String, + ForeignKey(f"{users_table}.user_id", ondelete="SET NULL"), + nullable=True, + ), Column( "created_at", TIMESTAMP(timezone=False), @@ -129,6 +162,7 @@ def __init__( onupdate=sql_text("CURRENT_TIMESTAMP"), nullable=False, ), + Index(f"idx_{sessions_table}_user_id", "user_id"), ) self._messages = Table( @@ -296,6 +330,22 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: async with self._session_factory() as sess: async with sess.begin(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + existing_user = await sess.execute( + select(self._users.c.user_id).where( + self._users.c.user_id == self.user_id + ) + ) + if not existing_user.scalar_one_or_none(): + try: + async with sess.begin_nested(): + await sess.execute( + insert(self._users).values({"user_id": self.user_id}) + ) + except IntegrityError: + pass + # Avoid check-then-insert races on the first write while keeping # the common path free of avoidable integrity exceptions. existing = await sess.execute( @@ -307,7 +357,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: try: async with sess.begin_nested(): await sess.execute( - insert(self._sessions).values({"session_id": self.session_id}) + insert(self._sessions).values( + {"session_id": self.session_id, "user_id": self.user_id} + ) ) except IntegrityError: # Another concurrent writer created the parent row first. @@ -372,6 +424,166 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) + @classmethod + async def create_session( + cls, + user_id: str, + *, + engine: AsyncEngine, + create_tables: bool = False, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLAlchemySession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + engine: A pre-configured SQLAlchemy async engine. + create_tables: Whether to auto-create tables. Defaults to False. + sessions_table: Override the default table name for sessions. + messages_table: Override the default table name for messages. + users_table: Override the default table name for users. + session_settings: Session configuration settings. + + Returns: + A new SQLAlchemySession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + await session._ensure_tables() + async with session._session_factory() as sess: + async with sess.begin(): + existing_user = await sess.execute( + select(session._users.c.user_id).where( + session._users.c.user_id == user_id + ) + ) + if not existing_user.scalar_one_or_none(): + try: + async with sess.begin_nested(): + await sess.execute( + insert(session._users).values({"user_id": user_id}) + ) + except IntegrityError: + pass + await sess.execute( + insert(session._sessions).values( + {"session_id": session_id, "user_id": user_id} + ) + ) + + return session + + @classmethod + async def get_sessions( + cls, + user_id: str, + *, + engine: AsyncEngine, + create_tables: bool = False, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + limit: int | None = None, + offset: int = 0, + ) -> list[SQLAlchemySession]: + """Retrieve all sessions for a user. + + Args: + user_id: The user identifier to look up sessions for. + engine: A pre-configured SQLAlchemy async engine. + create_tables: Whether to auto-create tables. Defaults to False. + sessions_table: Override the default table name for sessions. + messages_table: Override the default table name for messages. + users_table: Override the default table name for users. + session_settings: Session configuration settings. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of SQLAlchemySession instances belonging to the user, ordered by + most recently updated first. + """ + probe = cls( + session_id="__probe__", + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + await probe._ensure_tables() + async with probe._session_factory() as sess: + stmt = ( + select(probe._sessions.c.session_id) + .where(probe._sessions.c.user_id == user_id) + .order_by(probe._sessions.c.updated_at.desc()) + .offset(offset) + ) + if limit is not None: + stmt = stmt.limit(limit) + result = await sess.execute(stmt) + session_ids = [row[0] for row in result.all()] + + return [ + cls( + session_id=sid, + engine=engine, + create_tables=False, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + for sid in session_ids + ] + + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select(self._sessions.c.session_id) + .where(self._sessions.c.user_id == user_id) + .order_by(self._sessions.c.updated_at.desc()) + .offset(offset) + ) + if limit is not None: + stmt = stmt.limit(limit) + result = await sess.execute(stmt) + return [row[0] for row in result.all()] + @property def engine(self) -> AsyncEngine: """Access the underlying SQLAlchemy AsyncEngine. diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 85a65a1690..539f658f84 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -19,6 +19,7 @@ class Session(Protocol): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: @@ -65,6 +66,7 @@ class SessionABC(ABC): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None @abstractmethod diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..d8b329ce02 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -4,6 +4,7 @@ import json import sqlite3 import threading +import uuid from pathlib import Path from ..items import TResponseInputItem @@ -27,6 +28,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initialize the SQLite session. @@ -37,14 +40,19 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. + When provided, the session will be linked to the user in the agent_users table. session_settings: Session configuration settings including default limit for retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._local = threading.local() self._lock = threading.Lock() @@ -82,16 +90,37 @@ def _get_connection(self) -> sqlite3.Connection: def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -183,12 +212,22 @@ def _add_items_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + # Ensure session exists conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) # Add items @@ -273,6 +312,180 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) + @classmethod + async def create_session( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLiteSession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + session_settings: Session configuration settings. + + Returns: + A new SQLiteSession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + def _persist_session(): + conn = session._get_connection() + with session._lock if session._is_memory_db else threading.Lock(): + conn.execute( + f"INSERT OR IGNORE INTO {users_table} (user_id) VALUES (?)", + (user_id,), + ) + conn.execute( + f"INSERT INTO {sessions_table} (session_id, user_id) VALUES (?, ?)", + (session_id, user_id), + ) + conn.commit() + + await asyncio.to_thread(_persist_session) + return session + + @classmethod + async def get_sessions( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + limit: int | None = None, + offset: int = 0, + ) -> list[SQLiteSession]: + """Retrieve all sessions for a user. + + Args: + user_id: The user identifier to look up sessions for. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + session_settings: Session configuration settings. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of SQLiteSession instances belonging to the user, ordered by most + recently updated first. + """ + # Use a temporary instance to access the DB and query session IDs + probe = cls( + session_id="__probe__", + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + def _fetch_ids(): + conn = probe._get_connection() + with probe._lock if probe._is_memory_db else threading.Lock(): + if limit is None: + cursor = conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + return [row[0] for row in cursor.fetchall()] + + session_ids = await asyncio.to_thread(_fetch_ids) + probe.close() + + return [ + cls( + session_id=sid, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + for sid in session_ids + ] + + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + + def _get_sessions_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + if limit is None: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + return [row[0] for row in cursor.fetchall()] + + return await asyncio.to_thread(_get_sessions_sync) + def close(self) -> None: """Close the database connection.""" if self._is_memory_db: diff --git a/tests/test_session.py b/tests/test_session.py index aaa80ec7aa..8d4c90a9da 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -671,6 +671,219 @@ async def test_session_settings_resolve(): assert final_none.limit == 100 +@pytest.mark.asyncio +async def test_sqlite_session_user_association(): + """Test that sessions can be associated with users via user_id.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_users.db" + + # Create sessions for user_1 + session_a = SQLiteSession("session_a", db_path, user_id="user_1") + session_b = SQLiteSession("session_b", db_path, user_id="user_1") + # Create a session for user_2 + session_c = SQLiteSession("session_c", db_path, user_id="user_2") + # Create a session without a user + session_d = SQLiteSession("session_d", db_path) + + # Add items to trigger session/user creation + items: list[TResponseInputItem] = [{"role": "user", "content": "Hello"}] + await session_a.add_items(items) + await session_b.add_items(items) + await session_c.add_items(items) + await session_d.add_items(items) + + # Query sessions for user_1 + user_1_sessions = await session_a.get_sessions_for_user("user_1") + assert set(user_1_sessions) == {"session_a", "session_b"} + + # Query sessions for user_2 + user_2_sessions = await session_a.get_sessions_for_user("user_2") + assert user_2_sessions == ["session_c"] + + # Query sessions for non-existent user + empty_sessions = await session_a.get_sessions_for_user("user_999") + assert empty_sessions == [] + + session_a.close() + session_b.close() + session_c.close() + session_d.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_get_sessions_for_user_pagination(): + """Test limit and offset pagination on get_sessions_for_user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pagination.db" + items: list[TResponseInputItem] = [{"role": "user", "content": "hi"}] + + # Create 5 sessions for the same user, adding items sequentially so + # updated_at ordering is deterministic (most recent last). + session_ids = [f"s{i}" for i in range(5)] + sessions = [] + for sid in session_ids: + s = SQLiteSession(sid, db_path, user_id="paginated_user") + await s.add_items(items) + sessions.append(s) + + ref = sessions[0] # any session instance sharing the same db + + # Without limit — returns all 5 + all_ids = await ref.get_sessions_for_user("paginated_user") + assert len(all_ids) == 5 + + # limit=2 — returns the 2 most recently updated + page1 = await ref.get_sessions_for_user("paginated_user", limit=2) + assert len(page1) == 2 + + # limit=2, offset=2 — next page + page2 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=2) + assert len(page2) == 2 + + # limit=2, offset=4 — last page (only 1 left) + page3 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=4) + assert len(page3) == 1 + + # All pages together should cover all session ids + assert set(page1 + page2 + page3) == set(session_ids) + + # offset beyond total — empty + empty = await ref.get_sessions_for_user("paginated_user", offset=10) + assert empty == [] + + for s in sessions: + s.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_user_id_attribute(): + """Test that user_id is correctly stored on the session instance.""" + session_with_user = SQLiteSession("s1", user_id="alice") + assert session_with_user.user_id == "alice" + + session_without_user = SQLiteSession("s2") + assert session_without_user.user_id is None + + session_with_user.close() + session_without_user.close() + + +@pytest.mark.asyncio +async def test_sqlite_create_session(): + """Test create_session generates a UUID session_id and persists user and session.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_create.db" + + session = await SQLiteSession.create_session("alice", db_path=db_path) + + # session_id should be a valid UUID + import uuid + + uuid.UUID(session.session_id) # raises if not valid + assert session.user_id == "alice" + + # Session should be queryable via get_sessions_for_user + sessions = await session.get_sessions_for_user("alice") + assert session.session_id in sessions + + # Adding items should work normally + items: list[TResponseInputItem] = [{"role": "user", "content": "Hello"}] + await session.add_items(items) + retrieved = await session.get_items() + assert len(retrieved) == 1 + assert retrieved[0]["content"] == "Hello" + + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_get_sessions(): + """Test get_sessions retrieves all sessions for a user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_get.db" + + # Create sessions for bob + s1 = await SQLiteSession.create_session("bob", db_path=db_path) + s2 = await SQLiteSession.create_session("bob", db_path=db_path) + # Create a session for eve + await SQLiteSession.create_session("eve", db_path=db_path) + + # Add items to s1 + items: list[TResponseInputItem] = [{"role": "user", "content": "Hi Bob"}] + await s1.add_items(items) + s1.close() + s2.close() + + # Retrieve bob's sessions + bob_sessions = await SQLiteSession.get_sessions("bob", db_path=db_path) + assert len(bob_sessions) == 2 + session_ids = {s.session_id for s in bob_sessions} + assert s1.session_id in session_ids + assert s2.session_id in session_ids + + # Each returned session should be usable + for s in bob_sessions: + assert s.user_id == "bob" + if s.session_id == s1.session_id: + history = await s.get_items() + assert len(history) == 1 + assert history[0]["content"] == "Hi Bob" + s.close() + + # Non-existent user returns empty list + empty = await SQLiteSession.get_sessions("nobody", db_path=db_path) + assert empty == [] + + +@pytest.mark.asyncio +async def test_sqlite_get_sessions_pagination(): + """Test get_sessions supports limit and offset.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pagination_get.db" + + created = [] + for _ in range(5): + s = await SQLiteSession.create_session("paguser", db_path=db_path) + created.append(s) + s.close() + + page1 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2) + assert len(page1) == 2 + + page2 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2, offset=2) + assert len(page2) == 2 + + page3 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2, offset=4) + assert len(page3) == 1 + + all_ids = {s.session_id for s in page1 + page2 + page3} + assert all_ids == {s.session_id for s in created} + + for s in page1 + page2 + page3: + s.close() + + +@pytest.mark.asyncio +async def test_sqlite_create_multiple_sessions_for_user(): + """Test creating multiple sessions for the same user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_multi.db" + + s1 = await SQLiteSession.create_session("charlie", db_path=db_path) + s2 = await SQLiteSession.create_session("charlie", db_path=db_path) + + assert s1.session_id != s2.session_id + assert s1.user_id == s2.user_id == "charlie" + + sessions = await SQLiteSession.get_sessions("charlie", db_path=db_path) + assert {s.session_id for s in sessions} == {s1.session_id, s2.session_id} + + s1.close() + s2.close() + for s in sessions: + s.close() + + @pytest.mark.asyncio async def test_runner_with_session_settings_override(): """Test that RunConfig can override session's default settings."""