import json
import sqlite3
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import ClassVar

from ..items import TResponseInputItem
from .session import SessionABC
@@ -20,6 +23,9 @@ class SQLiteSession(SessionABC):
2023 """
2124
2225 session_settings : SessionSettings | None = None
26+ _file_locks : ClassVar [dict [Path , threading .RLock ]] = {}
27+ _file_lock_counts : ClassVar [dict [Path , int ]] = {}
28+ _file_locks_guard : ClassVar [threading .Lock ] = threading .Lock ()
2329
2430 def __init__ (
2531 self ,
@@ -46,21 +52,66 @@ def __init__(
4652 self .sessions_table = sessions_table
4753 self .messages_table = messages_table
4854 self ._local = threading .local ()
49- self ._lock = threading .Lock ()
5055
5156 # For in-memory databases, we need a shared connection to avoid thread isolation
5257 # For file databases, we use thread-local connections for better concurrency
5358 self ._is_memory_db = str (db_path ) == ":memory:"
59+ self ._lock_path : Path | None = None
60+ self ._lock_released = False
5461 if self ._is_memory_db :
55- self ._shared_connection = sqlite3 .connect (":memory:" , check_same_thread = False )
56- self ._shared_connection .execute ("PRAGMA journal_mode=WAL" )
57- self ._init_db_for_connection (self ._shared_connection )
62+ self ._lock = threading .RLock ()
5863 else :
59- # For file databases, initialize the schema once since it persists
60- init_conn = sqlite3 .connect (str (self .db_path ), check_same_thread = False )
61- init_conn .execute ("PRAGMA journal_mode=WAL" )
62- self ._init_db_for_connection (init_conn )
63- init_conn .close ()
64+ self ._lock_path , self ._lock = self ._acquire_file_lock (Path (self .db_path ))
65+
66+ try :
67+ if self ._is_memory_db :
68+ self ._shared_connection = sqlite3 .connect (":memory:" , check_same_thread = False )
69+ self ._shared_connection .execute ("PRAGMA journal_mode=WAL" )
70+ self ._init_db_for_connection (self ._shared_connection )
71+ else :
72+ # For file databases, initialize the schema once since it persists
73+ with self ._lock :
74+ init_conn = sqlite3 .connect (str (self .db_path ), check_same_thread = False )
75+ init_conn .execute ("PRAGMA journal_mode=WAL" )
76+ self ._init_db_for_connection (init_conn )
77+ init_conn .close ()
78+ except Exception :
79+ if self ._lock_path is not None and not self ._lock_released :
80+ self ._release_file_lock (self ._lock_path )
81+ self ._lock_released = True
82+ raise
83+
@classmethod
def _acquire_file_lock(cls, db_path: Path) -> tuple[Path, threading.RLock]:
    """Look up (or create) the process-local lock shared by every session on one SQLite file.

    Returns the canonicalized path used as the registry key together with the
    reentrant lock, and bumps the per-path refcount under the registry guard.
    """
    # Canonicalize so "./db.sqlite" and its absolute form share one lock.
    key = db_path.expanduser().resolve()
    with cls._file_locks_guard:
        if key not in cls._file_locks:
            cls._file_locks[key] = threading.RLock()
            cls._file_lock_counts[key] = 0
        cls._file_lock_counts[key] += 1
        return key, cls._file_locks[key]
96+
@classmethod
def _release_file_lock(cls, lock_path: Path) -> None:
    """Decrement the refcount for a shared file lock, discarding it when the last session closes."""
    with cls._file_locks_guard:
        remaining = cls._file_lock_counts.get(lock_path)
        if remaining is None:
            # Unknown path (already released, or never acquired): nothing to do.
            return
        if remaining > 1:
            cls._file_lock_counts[lock_path] = remaining - 1
        else:
            # Last holder: drop both registry entries so the lock can be GC'd.
            cls._file_locks.pop(lock_path, None)
            cls._file_lock_counts.pop(lock_path, None)
109+
@contextmanager
def _locked_connection(self) -> Iterator[sqlite3.Connection]:
    """Hold this session's lock for the duration of one sqlite3 operation.

    Serializes database access while each operation runs in a worker thread;
    yields the (thread-local or shared) connection from ``_get_connection``.
    """
    self._lock.acquire()
    try:
        yield self._get_connection()
    finally:
        self._lock.release()
64115
65116 def _get_connection (self ) -> sqlite3 .Connection :
66117 """Get a database connection."""
@@ -114,6 +165,31 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
114165
115166 conn .commit ()
116167
def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem]) -> None:
    """Write *items* to this session's message table and refresh its timestamp.

    Does not commit; the caller owns the transaction boundary.
    """
    # Make sure the parent session row exists before attaching messages to it.
    conn.execute(
        f"""
        INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
        """,
        (self.session_id,),
    )

    serialized_rows = [(self.session_id, json.dumps(item)) for item in items]
    conn.executemany(
        f"""
        INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
        """,
        serialized_rows,
    )

    # Record the write time on the session row.
    conn.execute(
        f"""
        UPDATE {self.sessions_table}
        SET updated_at = CURRENT_TIMESTAMP
        WHERE session_id = ?
        """,
        (self.session_id,),
    )
117193 async def get_items (self , limit : int | None = None ) -> list [TResponseInputItem ]:
118194 """Retrieve the conversation history for this session.
119195
@@ -127,8 +203,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
127203 session_limit = resolve_session_limit (limit , self .session_settings )
128204
129205 def _get_items_sync ():
130- conn = self ._get_connection ()
131- with self ._lock if self ._is_memory_db else threading .Lock ():
206+ with self ._locked_connection () as conn :
132207 if session_limit is None :
133208 # Fetch all items in chronological order
134209 cursor = conn .execute (
@@ -180,36 +255,8 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
180255 return
181256
182257 def _add_items_sync ():
183- conn = self ._get_connection ()
184-
185- with self ._lock if self ._is_memory_db else threading .Lock ():
186- # Ensure session exists
187- conn .execute (
188- f"""
189- INSERT OR IGNORE INTO { self .sessions_table } (session_id) VALUES (?)
190- """ ,
191- (self .session_id ,),
192- )
193-
194- # Add items
195- message_data = [(self .session_id , json .dumps (item )) for item in items ]
196- conn .executemany (
197- f"""
198- INSERT INTO { self .messages_table } (session_id, message_data) VALUES (?, ?)
199- """ ,
200- message_data ,
201- )
202-
203- # Update session timestamp
204- conn .execute (
205- f"""
206- UPDATE { self .sessions_table }
207- SET updated_at = CURRENT_TIMESTAMP
208- WHERE session_id = ?
209- """ ,
210- (self .session_id ,),
211- )
212-
258+ with self ._locked_connection () as conn :
259+ self ._insert_items (conn , items )
213260 conn .commit ()
214261
215262 await asyncio .to_thread (_add_items_sync )
@@ -222,8 +269,7 @@ async def pop_item(self) -> TResponseInputItem | None:
222269 """
223270
224271 def _pop_item_sync ():
225- conn = self ._get_connection ()
226- with self ._lock if self ._is_memory_db else threading .Lock ():
272+ with self ._locked_connection () as conn :
227273 # Use DELETE with RETURNING to atomically delete and return the most recent item
228274 cursor = conn .execute (
229275 f"""
@@ -259,8 +305,7 @@ async def clear_session(self) -> None:
259305 """Clear all items for this session."""
260306
261307 def _clear_session_sync ():
262- conn = self ._get_connection ()
263- with self ._lock if self ._is_memory_db else threading .Lock ():
308+ with self ._locked_connection () as conn :
264309 conn .execute (
265310 f"DELETE FROM { self .messages_table } WHERE session_id = ?" ,
266311 (self .session_id ,),
@@ -281,3 +326,6 @@ def close(self) -> None:
281326 else :
282327 if hasattr (self ._local , "connection" ):
283328 self ._local .connection .close ()
329+ if self ._lock_path is not None and not self ._lock_released :
330+ self ._release_file_lock (self ._lock_path )
331+ self ._lock_released = True
0 commit comments