Skip to content

Commit a17fc3f

Browse files
gustavzseratch
andauthored
Add AsyncSQLiteSession (aiosqlite-backed async session store) (#2284)
Co-authored-by: Kazuhiro Sera <seratch@openai.com>
1 parent 09443fd commit a17fc3f

4 files changed

Lines changed: 553 additions & 4 deletions

File tree

src/agents/extensions/memory/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
if TYPE_CHECKING:
1414
from .advanced_sqlite_session import AdvancedSQLiteSession
15+
from .async_sqlite_session import AsyncSQLiteSession
1516
from .dapr_session import (
1617
DAPR_CONSISTENCY_EVENTUAL,
1718
DAPR_CONSISTENCY_STRONG,
@@ -23,6 +24,7 @@
2324

2425
__all__: list[str] = [
2526
"AdvancedSQLiteSession",
27+
"AsyncSQLiteSession",
2628
"DAPR_CONSISTENCY_EVENTUAL",
2729
"DAPR_CONSISTENCY_STRONG",
2830
"DaprSession",
@@ -74,6 +76,14 @@ def __getattr__(name: str) -> Any:
7476
except ModuleNotFoundError as e:
7577
raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e
7678

79+
if name == "AsyncSQLiteSession":
80+
try:
81+
from .async_sqlite_session import AsyncSQLiteSession # noqa: F401
82+
83+
return AsyncSQLiteSession
84+
except ModuleNotFoundError as e:
85+
raise ImportError(f"Failed to import AsyncSQLiteSession: {e}") from e
86+
7787
if name == "DaprSession":
7888
try:
7989
from .dapr_session import DaprSession # noqa: F401
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import json
5+
from collections.abc import AsyncIterator
6+
from contextlib import asynccontextmanager
7+
from pathlib import Path
8+
from typing import cast
9+
10+
import aiosqlite
11+
12+
from ...items import TResponseInputItem
13+
from ...memory import SessionABC
14+
15+
16+
class AsyncSQLiteSession(SessionABC):
17+
"""Async SQLite-based implementation of session storage.
18+
19+
This implementation stores conversation history in a SQLite database.
20+
By default, uses an in-memory database that is lost when the process ends.
21+
For persistent storage, provide a file path.
22+
"""
23+
24+
def __init__(
25+
self,
26+
session_id: str,
27+
db_path: str | Path = ":memory:",
28+
sessions_table: str = "agent_sessions",
29+
messages_table: str = "agent_messages",
30+
):
31+
"""Initialize the async SQLite session.
32+
33+
Args:
34+
session_id: Unique identifier for the conversation session
35+
db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
36+
sessions_table: Name of the table to store session metadata. Defaults to
37+
'agent_sessions'
38+
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
39+
"""
40+
self.session_id = session_id
41+
self.db_path = db_path
42+
self.sessions_table = sessions_table
43+
self.messages_table = messages_table
44+
self._connection: aiosqlite.Connection | None = None
45+
self._lock = asyncio.Lock()
46+
self._init_lock = asyncio.Lock()
47+
48+
async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None:
49+
"""Initialize the database schema for a specific connection."""
50+
await conn.execute(
51+
f"""
52+
CREATE TABLE IF NOT EXISTS {self.sessions_table} (
53+
session_id TEXT PRIMARY KEY,
54+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
55+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
56+
)
57+
"""
58+
)
59+
60+
await conn.execute(
61+
f"""
62+
CREATE TABLE IF NOT EXISTS {self.messages_table} (
63+
id INTEGER PRIMARY KEY AUTOINCREMENT,
64+
session_id TEXT NOT NULL,
65+
message_data TEXT NOT NULL,
66+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
67+
FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
68+
ON DELETE CASCADE
69+
)
70+
"""
71+
)
72+
73+
await conn.execute(
74+
f"""
75+
CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
76+
ON {self.messages_table} (session_id, id)
77+
"""
78+
)
79+
80+
await conn.commit()
81+
82+
async def _get_connection(self) -> aiosqlite.Connection:
83+
"""Get or create a database connection."""
84+
if self._connection is not None:
85+
return self._connection
86+
87+
async with self._init_lock:
88+
if self._connection is None:
89+
self._connection = await aiosqlite.connect(str(self.db_path))
90+
await self._connection.execute("PRAGMA journal_mode=WAL")
91+
await self._init_db_for_connection(self._connection)
92+
93+
return self._connection
94+
95+
@asynccontextmanager
96+
async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
97+
"""Provide a connection under the session lock."""
98+
async with self._lock:
99+
conn = await self._get_connection()
100+
yield conn
101+
102+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
103+
"""Retrieve the conversation history for this session.
104+
105+
Args:
106+
limit: Maximum number of items to retrieve. If None, retrieves all items.
107+
When specified, returns the latest N items in chronological order.
108+
109+
Returns:
110+
List of input items representing the conversation history
111+
"""
112+
113+
async with self._locked_connection() as conn:
114+
if limit is None:
115+
cursor = await conn.execute(
116+
f"""
117+
SELECT message_data FROM {self.messages_table}
118+
WHERE session_id = ?
119+
ORDER BY id ASC
120+
""",
121+
(self.session_id,),
122+
)
123+
else:
124+
cursor = await conn.execute(
125+
f"""
126+
SELECT message_data FROM {self.messages_table}
127+
WHERE session_id = ?
128+
ORDER BY id DESC
129+
LIMIT ?
130+
""",
131+
(self.session_id, limit),
132+
)
133+
134+
rows = list(await cursor.fetchall())
135+
await cursor.close()
136+
137+
if limit is not None:
138+
rows = rows[::-1]
139+
140+
items: list[TResponseInputItem] = []
141+
for (message_data,) in rows:
142+
try:
143+
item = json.loads(message_data)
144+
items.append(item)
145+
except json.JSONDecodeError:
146+
continue
147+
148+
return items
149+
150+
async def add_items(self, items: list[TResponseInputItem]) -> None:
151+
"""Add new items to the conversation history.
152+
153+
Args:
154+
items: List of input items to add to the history
155+
"""
156+
if not items:
157+
return
158+
159+
async with self._locked_connection() as conn:
160+
await conn.execute(
161+
f"""
162+
INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
163+
""",
164+
(self.session_id,),
165+
)
166+
167+
message_data = [(self.session_id, json.dumps(item)) for item in items]
168+
await conn.executemany(
169+
f"""
170+
INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
171+
""",
172+
message_data,
173+
)
174+
175+
await conn.execute(
176+
f"""
177+
UPDATE {self.sessions_table}
178+
SET updated_at = CURRENT_TIMESTAMP
179+
WHERE session_id = ?
180+
""",
181+
(self.session_id,),
182+
)
183+
184+
await conn.commit()
185+
186+
async def pop_item(self) -> TResponseInputItem | None:
187+
"""Remove and return the most recent item from the session.
188+
189+
Returns:
190+
The most recent item if it exists, None if the session is empty
191+
"""
192+
async with self._locked_connection() as conn:
193+
cursor = await conn.execute(
194+
f"""
195+
DELETE FROM {self.messages_table}
196+
WHERE id = (
197+
SELECT id FROM {self.messages_table}
198+
WHERE session_id = ?
199+
ORDER BY id DESC
200+
LIMIT 1
201+
)
202+
RETURNING message_data
203+
""",
204+
(self.session_id,),
205+
)
206+
207+
result = await cursor.fetchone()
208+
await cursor.close()
209+
await conn.commit()
210+
211+
if result:
212+
message_data = result[0]
213+
try:
214+
return cast(TResponseInputItem, json.loads(message_data))
215+
except json.JSONDecodeError:
216+
return None
217+
218+
return None
219+
220+
async def clear_session(self) -> None:
221+
"""Clear all items for this session."""
222+
async with self._locked_connection() as conn:
223+
await conn.execute(
224+
f"DELETE FROM {self.messages_table} WHERE session_id = ?",
225+
(self.session_id,),
226+
)
227+
await conn.execute(
228+
f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
229+
(self.session_id,),
230+
)
231+
await conn.commit()
232+
233+
async def close(self) -> None:
234+
"""Close the database connection."""
235+
if self._connection is None:
236+
return
237+
async with self._lock:
238+
await self._connection.close()
239+
self._connection = None

src/agents/memory/sqlite_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
101101
conn.execute(
102102
f"""
103103
CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
104-
ON {self.messages_table} (session_id, created_at)
104+
ON {self.messages_table} (session_id, id)
105105
"""
106106
)
107107

@@ -127,7 +127,7 @@ def _get_items_sync():
127127
f"""
128128
SELECT message_data FROM {self.messages_table}
129129
WHERE session_id = ?
130-
ORDER BY created_at ASC
130+
ORDER BY id ASC
131131
""",
132132
(self.session_id,),
133133
)
@@ -137,7 +137,7 @@ def _get_items_sync():
137137
f"""
138138
SELECT message_data FROM {self.messages_table}
139139
WHERE session_id = ?
140-
ORDER BY created_at DESC
140+
ORDER BY id DESC
141141
LIMIT ?
142142
""",
143143
(self.session_id, limit),
@@ -223,7 +223,7 @@ def _pop_item_sync():
223223
WHERE id = (
224224
SELECT id FROM {self.messages_table}
225225
WHERE session_id = ?
226-
ORDER BY created_at DESC
226+
ORDER BY id DESC
227227
LIMIT 1
228228
)
229229
RETURNING message_data

0 commit comments

Comments
 (0)