Skip to content

Commit 6071fdb

Browse files
committed
Fix AdvancedSQLiteSession add_items atomicity
1 parent 92e014a commit 6071fdb

2 files changed

Lines changed: 187 additions & 25 deletions

File tree

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -133,26 +133,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
133133
def _add_items_sync():
134134
"""Synchronous helper to add items and structure metadata together."""
135135
with self._locked_connection() as conn:
136-
# Keep both writes in one critical section so message IDs and metadata stay aligned.
137-
self._insert_items(conn, items)
138-
conn.commit()
139136
try:
137+
# Keep both writes in one transaction so metadata failures do not leave orphans.
138+
self._insert_items(conn, items)
140139
self._insert_structure_metadata(conn, items)
141140
conn.commit()
142-
except Exception as e:
141+
except Exception:
143142
conn.rollback()
144-
self._logger.error(
145-
f"Failed to add structure metadata for session {self.session_id}: {e}"
146-
)
147-
try:
148-
deleted_count = self._cleanup_orphaned_messages_sync(conn)
149-
if deleted_count:
150-
conn.commit()
151-
else:
152-
conn.rollback()
153-
except Exception as cleanup_error:
154-
conn.rollback()
155-
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
143+
self._logger.exception("Failed to add items for session %s", self.session_id)
144+
raise
156145

157146
await asyncio.to_thread(_add_items_sync)
158147

@@ -367,16 +356,16 @@ def _add_structure_sync():
367356

368357
try:
369358
await asyncio.to_thread(_add_structure_sync)
370-
except Exception as e:
371-
self._logger.error(
372-
f"Failed to add structure metadata for session {self.session_id}: {e}"
359+
except Exception:
360+
self._logger.exception(
361+
"Failed to add structure metadata for session %s", self.session_id
373362
)
374-
# Try to clean up any orphaned messages to maintain consistency
363+
# Try to clean up any orphaned messages to maintain consistency.
375364
try:
376365
await self._cleanup_orphaned_messages()
377-
except Exception as cleanup_error:
378-
self._logger.error(f"Failed to cleanup orphaned messages: {cleanup_error}")
379-
# Don't re-raise - structure metadata is supplementary
366+
except Exception:
367+
self._logger.exception("Failed to cleanup orphaned messages")
368+
raise
380369

381370
def _insert_structure_metadata(
382371
self,
@@ -469,8 +458,8 @@ def _insert_structure_metadata(
469458
async def _cleanup_orphaned_messages(self) -> int:
470459
"""Remove messages that exist in the configured message table but not in message_structure.
471460
472-
This can happen if _add_structure_metadata fails after super().add_items() succeeds.
473-
Used for maintaining data consistency.
461+
This can happen for rows written by older or non-atomic structure metadata paths.
462+
`add_items()` writes message rows and structure metadata in a single transaction.
474463
"""
475464

476465
def _cleanup_sync():

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,52 @@ def create_mock_run_result(usage: Usage | None = None, agent: Agent | None = Non
8080
)
8181

8282

83+
class FailingOnceStructureMetadataSession(AdvancedSQLiteSession):
84+
"""Advanced session test double that fails the next structure metadata write."""
85+
86+
def __init__(self, **kwargs: Any):
87+
super().__init__(**kwargs)
88+
self.fail_structure_metadata_once = True
89+
90+
def _insert_structure_metadata(
91+
self,
92+
conn: Any,
93+
items: list[TResponseInputItem],
94+
) -> None:
95+
if self.fail_structure_metadata_once:
96+
self.fail_structure_metadata_once = False
97+
raise RuntimeError("structure metadata failed")
98+
super()._insert_structure_metadata(conn, items)
99+
100+
101+
class PartiallyFailingStructureMetadataSession(AdvancedSQLiteSession):
102+
"""Advanced session test double that fails after writing one structure row."""
103+
104+
def _insert_structure_metadata(
105+
self,
106+
conn: Any,
107+
items: list[TResponseInputItem],
108+
) -> None:
109+
cursor = conn.execute(
110+
f"SELECT id FROM {self.messages_table} WHERE session_id = ? ORDER BY id ASC LIMIT 1",
111+
(self.session_id,),
112+
)
113+
row = cursor.fetchone()
114+
if row is None:
115+
raise RuntimeError("no inserted message id found")
116+
117+
conn.execute(
118+
"""
119+
INSERT INTO message_structure
120+
(session_id, message_id, branch_id, message_type, sequence_number,
121+
user_turn_number, branch_turn_number, tool_name)
122+
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
123+
""",
124+
(self.session_id, row[0], self._current_branch_id, "user", 1, 1, 1, None),
125+
)
126+
raise RuntimeError("structure metadata failed after partial write")
127+
128+
83129
async def test_advanced_session_basic_functionality(agent: Agent):
84130
"""Test basic AdvancedSQLiteSession functionality."""
85131
session_id = "advanced_test"
@@ -147,6 +193,133 @@ async def test_advanced_session_respects_custom_table_names():
147193
session.close()
148194

149195

196+
async def test_add_items_rolls_back_messages_when_structure_metadata_fails():
197+
"""Failed structure metadata writes should not leave invisible message rows."""
198+
session = FailingOnceStructureMetadataSession(
199+
session_id="advanced_add_items_rollback",
200+
create_tables=True,
201+
)
202+
items: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]
203+
204+
try:
205+
with pytest.raises(RuntimeError, match="structure metadata failed"):
206+
await session.add_items(items)
207+
208+
assert await session.get_items() == []
209+
210+
with session._locked_connection() as conn:
211+
message_count = conn.execute(
212+
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
213+
(session.session_id,),
214+
).fetchone()[0]
215+
structure_count = conn.execute(
216+
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
217+
(session.session_id,),
218+
).fetchone()[0]
219+
220+
assert message_count == 0
221+
assert structure_count == 0
222+
finally:
223+
session.close()
224+
225+
226+
async def test_add_items_can_retry_after_structure_metadata_failure():
227+
"""Retrying after a metadata failure should persist the batch exactly once."""
228+
session = FailingOnceStructureMetadataSession(
229+
session_id="advanced_add_items_retry",
230+
create_tables=True,
231+
)
232+
items: list[TResponseInputItem] = [{"role": "user", "content": "saved once"}]
233+
234+
try:
235+
with pytest.raises(RuntimeError, match="structure metadata failed"):
236+
await session.add_items(items)
237+
238+
await session.add_items(items)
239+
240+
assert await session.get_items() == items
241+
242+
with session._locked_connection() as conn:
243+
message_count = conn.execute(
244+
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
245+
(session.session_id,),
246+
).fetchone()[0]
247+
structure_count = conn.execute(
248+
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
249+
(session.session_id,),
250+
).fetchone()[0]
251+
252+
assert message_count == 1
253+
assert structure_count == 1
254+
finally:
255+
session.close()
256+
257+
258+
async def test_add_items_failure_preserves_existing_history():
259+
"""A failed batch should not roll back or hide previously committed messages."""
260+
session = FailingOnceStructureMetadataSession(
261+
session_id="advanced_add_items_existing_history",
262+
create_tables=True,
263+
)
264+
existing_items: list[TResponseInputItem] = [{"role": "user", "content": "already saved"}]
265+
failed_items: list[TResponseInputItem] = [{"role": "assistant", "content": "not saved"}]
266+
267+
try:
268+
session.fail_structure_metadata_once = False
269+
await session.add_items(existing_items)
270+
271+
session.fail_structure_metadata_once = True
272+
with pytest.raises(RuntimeError, match="structure metadata failed"):
273+
await session.add_items(failed_items)
274+
275+
assert await session.get_items() == existing_items
276+
277+
with session._locked_connection() as conn:
278+
message_count = conn.execute(
279+
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
280+
(session.session_id,),
281+
).fetchone()[0]
282+
structure_count = conn.execute(
283+
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
284+
(session.session_id,),
285+
).fetchone()[0]
286+
287+
assert message_count == 1
288+
assert structure_count == 1
289+
finally:
290+
session.close()
291+
292+
293+
async def test_add_items_rolls_back_partial_structure_metadata_write():
294+
"""Partial metadata writes should roll back with the message rows in the same batch."""
295+
session = PartiallyFailingStructureMetadataSession(
296+
session_id="advanced_add_items_partial_metadata",
297+
create_tables=True,
298+
)
299+
items: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]
300+
301+
try:
302+
with pytest.raises(RuntimeError, match="structure metadata failed after partial write"):
303+
await session.add_items(items)
304+
305+
assert await session.get_items() == []
306+
307+
with session._locked_connection() as conn:
308+
message_count = conn.execute(
309+
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
310+
(session.session_id,),
311+
).fetchone()[0]
312+
structure_count = conn.execute(
313+
"SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
314+
(session.session_id,),
315+
).fetchone()[0]
316+
317+
assert message_count == 0
318+
assert structure_count == 0
319+
finally:
320+
session.close()
321+
322+
150323
async def test_message_structure_tracking(agent: Agent):
151324
"""Test that message structure is properly tracked."""
152325
session_id = "structure_test"

0 commit comments

Comments
 (0)