Skip to content

Commit 8833b08

Browse files
authored
Update advanced_sqlite_session.py
1 parent 76f419f commit 8833b08

1 file changed

Lines changed: 19 additions & 16 deletions

File tree

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ def __init__(
4949
self._init_structure_tables()
5050
self._current_branch_id = "main"
5151
self._logger = logger or logging.getLogger(__name__)
52+
# Create a dedicated lock for disk-based databases to ensure thread safety
53+
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock
54+
self._file_db_lock = threading.Lock()
5255

5356
def _init_structure_tables(self):
5457
"""Add structure and usage tracking tables.
@@ -158,8 +161,8 @@ async def get_items(
158161
def _get_all_items_sync():
159162
"""Synchronous helper to get all items for a branch."""
160163
conn = self._get_connection()
161-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
162-
with self._lock if self._is_memory_db else threading.Lock():
164+
# Use the instance lock for disk-based DBs to ensure consistent locking
165+
with self._lock if self._is_memory_db else self._file_db_lock:
163166
with closing(conn.cursor()) as cursor:
164167
if session_limit is None:
165168
cursor.execute(
@@ -203,8 +206,8 @@ def _get_all_items_sync():
203206
def _get_items_sync():
204207
"""Synchronous helper to get items for a specific branch."""
205208
conn = self._get_connection()
206-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
207-
with self._lock if self._is_memory_db else threading.Lock():
209+
# Use the instance lock for disk-based DBs to ensure consistent locking
210+
with self._lock if self._is_memory_db else self._file_db_lock:
208211
with closing(conn.cursor()) as cursor:
209212
# Get message IDs in correct order for this branch
210213
if session_limit is None:
@@ -345,8 +348,8 @@ async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None
345348
def _add_structure_sync():
346349
"""Synchronous helper to add structure metadata to database."""
347350
conn = self._get_connection()
348-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
349-
with self._lock if self._is_memory_db else threading.Lock():
351+
# Use the instance lock for disk-based DBs to ensure consistent locking
352+
with self._lock if self._is_memory_db else self._file_db_lock:
350353
# Get the IDs of messages we just inserted, in order
351354
with closing(conn.cursor()) as cursor:
352355
cursor.execute(
@@ -451,8 +454,8 @@ async def _cleanup_orphaned_messages(self) -> int:
451454
def _cleanup_sync():
452455
"""Synchronous helper to cleanup orphaned messages."""
453456
conn = self._get_connection()
454-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
455-
with self._lock if self._is_memory_db else threading.Lock():
457+
# Use the instance lock for disk-based DBs to ensure consistent locking
458+
with self._lock if self._is_memory_db else self._file_db_lock:
456459
with closing(conn.cursor()) as cursor:
457460
# Find messages without structure metadata
458461
cursor.execute(
@@ -722,8 +725,8 @@ async def delete_branch(self, branch_id: str, force: bool = False) -> None:
722725
def _delete_sync():
723726
"""Synchronous helper to delete branch and associated data."""
724727
conn = self._get_connection()
725-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
726-
with self._lock if self._is_memory_db else threading.Lock():
728+
# Use the instance lock for disk-based DBs to ensure consistent locking
729+
with self._lock if self._is_memory_db else self._file_db_lock:
727730
with closing(conn.cursor()) as cursor:
728731
# First verify the branch exists
729732
cursor.execute(
@@ -829,8 +832,8 @@ async def _copy_messages_to_new_branch(self, new_branch_id: str, from_turn_numbe
829832
def _copy_sync():
830833
"""Synchronous helper to copy messages to new branch."""
831834
conn = self._get_connection()
832-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
833-
with self._lock if self._is_memory_db else threading.Lock():
835+
# Use the instance lock for disk-based DBs to ensure consistent locking
836+
with self._lock if self._is_memory_db else self._file_db_lock:
834837
with closing(conn.cursor()) as cursor:
835838
# Get all messages before the branch point
836839
cursor.execute(
@@ -1124,8 +1127,8 @@ async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int
11241127
def _get_usage_sync():
11251128
"""Synchronous helper to get session usage data."""
11261129
conn = self._get_connection()
1127-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1128-
with self._lock if self._is_memory_db else threading.Lock():
1130+
# Use the instance lock for disk-based DBs to ensure consistent locking
1131+
with self._lock if self._is_memory_db else self._file_db_lock:
11291132
if branch_id:
11301133
# Branch-specific usage
11311134
query = """
@@ -1288,8 +1291,8 @@ async def _update_turn_usage_internal(self, user_turn_number: int, usage_data: U
12881291
def _update_sync():
12891292
"""Synchronous helper to update turn usage data."""
12901293
conn = self._get_connection()
1291-
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
1292-
with self._lock if self._is_memory_db else threading.Lock():
1294+
# Use the instance lock for disk-based DBs to ensure consistent locking
1295+
with self._lock if self._is_memory_db else self._file_db_lock:
12931296
# Serialize token details as JSON
12941297
input_details_json = None
12951298
output_details_json = None

0 commit comments

Comments
 (0)