Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 199 additions & 3 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import uuid
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
Expand Down Expand Up @@ -30,6 +31,8 @@ def __init__(
db_path: str | Path = ":memory:",
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
users_table: str = "agent_users",
user_id: str | None = None,
):
"""Initialize the async SQLite session.

Expand All @@ -39,27 +42,52 @@ def __init__(
sessions_table: Name of the table to store session metadata. Defaults to
'agent_sessions'
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
users_table: Name of the table to store user metadata. Defaults to 'agent_users'
user_id: Optional user identifier to associate this session with a user.
"""
self.session_id = session_id
self.user_id = user_id
self.db_path = db_path
self.sessions_table = sessions_table
self.messages_table = messages_table
self.users_table = users_table
self._connection: aiosqlite.Connection | None = None
self._lock = asyncio.Lock()
self._init_lock = asyncio.Lock()

async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None:
"""Initialize the database schema for a specific connection."""
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.users_table} (
user_id TEXT PRIMARY KEY,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)

await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.sessions_table} (
session_id TEXT PRIMARY KEY,
user_id TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id)
ON DELETE SET NULL
)
"""
)

await conn.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id
ON {self.sessions_table} (user_id)
"""
)

await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.messages_table} (
Expand Down Expand Up @@ -160,11 +188,21 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
return

async with self._locked_connection() as conn:
# Ensure user exists if user_id is provided
if self.user_id is not None:
await conn.execute(
f"""
INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?)
""",
(self.user_id,),
)

await conn.execute(
f"""
INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id)
VALUES (?, ?)
""",
(self.session_id,),
(self.session_id, self.user_id),
)

message_data = [(self.session_id, json.dumps(item)) for item in items]
Expand Down Expand Up @@ -233,6 +271,164 @@ async def clear_session(self) -> None:
)
await conn.commit()

@classmethod
async def create_session(
cls,
user_id: str,
db_path: str | Path = ":memory:",
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
users_table: str = "agent_users",
) -> AsyncSQLiteSession:
"""Create a new session for a user with an auto-generated session ID.

Args:
user_id: The user identifier to associate with the new session.
db_path: Path to the SQLite database file. Defaults to ':memory:'.
sessions_table: Name of the sessions table. Defaults to 'agent_sessions'.
messages_table: Name of the messages table. Defaults to 'agent_messages'.
users_table: Name of the users table. Defaults to 'agent_users'.

Returns:
A new AsyncSQLiteSession instance with an auto-generated session_id.
"""
session_id = str(uuid.uuid4())
session = cls(
session_id=session_id,
db_path=db_path,
sessions_table=sessions_table,
messages_table=messages_table,
users_table=users_table,
user_id=user_id,
)

async with session._locked_connection() as conn:
await conn.execute(
f"INSERT OR IGNORE INTO {users_table} (user_id) VALUES (?)",
(user_id,),
)
await conn.execute(
f"INSERT INTO {sessions_table} (session_id, user_id) VALUES (?, ?)",
(session_id, user_id),
)
await conn.commit()

return session

@classmethod
async def get_sessions(
cls,
user_id: str,
db_path: str | Path = ":memory:",
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
users_table: str = "agent_users",
limit: int | None = None,
offset: int = 0,
) -> list[AsyncSQLiteSession]:
"""Retrieve all sessions for a user.

Args:
user_id: The user identifier to look up sessions for.
db_path: Path to the SQLite database file. Defaults to ':memory:'.
sessions_table: Name of the sessions table. Defaults to 'agent_sessions'.
messages_table: Name of the messages table. Defaults to 'agent_messages'.
users_table: Name of the users table. Defaults to 'agent_users'.
limit: Maximum number of sessions to return. If None, returns all sessions.
offset: Number of sessions to skip before returning results. Defaults to 0.

Returns:
List of AsyncSQLiteSession instances belonging to the user, ordered by
most recently updated first.
"""
probe = cls(
session_id="__probe__",
db_path=db_path,
sessions_table=sessions_table,
messages_table=messages_table,
users_table=users_table,
user_id=user_id,
)

async with probe._locked_connection() as conn:
if limit is None:
cursor = await conn.execute(
f"""
SELECT session_id FROM {sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT -1 OFFSET ?
""",
(user_id, offset),
)
else:
cursor = await conn.execute(
f"""
SELECT session_id FROM {sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(user_id, limit, offset),
)
rows = await cursor.fetchall()
await cursor.close()

await probe.close()

return [
cls(
session_id=row[0],
db_path=db_path,
sessions_table=sessions_table,
messages_table=messages_table,
users_table=users_table,
user_id=user_id,
)
for row in rows
]

async def get_sessions_for_user(
self,
user_id: str,
limit: int | None = None,
offset: int = 0,
) -> list[str]:
"""Retrieve session IDs associated with a given user.

Args:
user_id: The user identifier to look up sessions for.
limit: Maximum number of session IDs to return. If None, returns all sessions.
offset: Number of sessions to skip before returning results. Defaults to 0.

Returns:
List of session IDs belonging to the user, ordered by most recently updated first.
"""
async with self._locked_connection() as conn:
if limit is None:
cursor = await conn.execute(
f"""
SELECT session_id FROM {self.sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT -1 OFFSET ?
""",
(user_id, offset),
)
else:
cursor = await conn.execute(
f"""
SELECT session_id FROM {self.sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(user_id, limit, offset),
)
rows = await cursor.fetchall()
await cursor.close()
return [row[0] for row in rows]

async def close(self) -> None:
"""Close the database connection."""
if self._connection is None:
Expand Down
Loading