Skip to content

Commit 5732480

Browse files
committed
fix(sessions): clean up branch-only messages on delete
1 parent 92e014a commit 5732480

2 files changed

Lines changed: 181 additions & 22 deletions

File tree

src/agents/extensions/memory/advanced_sqlite_session.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -487,30 +487,22 @@ def _cleanup_sync():
487487

488488
def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int:
489489
with closing(conn.cursor()) as cursor:
490-
# Find messages without structure metadata.
491490
cursor.execute(
492491
f"""
493-
SELECT am.id
494-
FROM {self.messages_table} am
495-
LEFT JOIN message_structure ms ON am.id = ms.message_id
496-
WHERE am.session_id = ? AND ms.message_id IS NULL
497-
""",
498-
(self.session_id,),
499-
)
500-
501-
orphaned_ids = [row[0] for row in cursor.fetchall()]
502-
503-
if not orphaned_ids:
504-
return 0
505-
506-
placeholders = ",".join("?" * len(orphaned_ids))
507-
cursor.execute(
508-
f"DELETE FROM {self.messages_table} WHERE id IN ({placeholders})",
509-
orphaned_ids,
492+
DELETE FROM {self.messages_table}
493+
WHERE session_id = ?
494+
AND id NOT IN (
495+
SELECT message_id
496+
FROM message_structure ms
497+
WHERE ms.session_id = ?
498+
)
499+
""",
500+
(self.session_id, self.session_id),
510501
)
511502

512503
deleted_count = cursor.rowcount
513-
self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
504+
if deleted_count:
505+
self._logger.info(f"Cleaned up {deleted_count} orphaned messages")
514506
return deleted_count
515507

516508
def _classify_message_type(self, item: TResponseInputItem) -> str:
@@ -786,14 +778,19 @@ def _delete_sync():
786778

787779
structure_deleted = cursor.rowcount
788780

781+
orphaned_messages_deleted = self._cleanup_orphaned_messages_sync(conn)
782+
789783
conn.commit()
790784

791-
return usage_deleted, structure_deleted
785+
return usage_deleted, structure_deleted, orphaned_messages_deleted
792786

793-
usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync)
787+
usage_deleted, structure_deleted, orphaned_messages_deleted = await asyncio.to_thread(
788+
_delete_sync
789+
)
794790

795791
self._logger.info(
796-
f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501
792+
f"Deleted branch '{branch_id}': {structure_deleted} message entries, "
793+
f"{usage_deleted} usage entries, {orphaned_messages_deleted} orphaned messages"
797794
)
798795

799796
async def list_branches(self) -> list[dict[str, Any]]:

tests/extensions/memory/test_advanced_sqlite_session.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,168 @@ async def test_branching_functionality(agent: Agent):
442442
session.close()
443443

444444

445+
async def test_delete_branch_removes_branch_only_messages():
446+
"""Deleting a branch should not leave unreferenced branch-only messages behind."""
447+
session_id = "branch_delete_cleanup_test"
448+
session = AdvancedSQLiteSession(session_id=session_id, create_tables=True)
449+
450+
main_items: list[TResponseInputItem] = [
451+
{"role": "user", "content": "First question"},
452+
{"role": "assistant", "content": "First answer"},
453+
{"role": "user", "content": "Second question"},
454+
{"role": "assistant", "content": "Second answer"},
455+
]
456+
await session.add_items(main_items)
457+
458+
await session.create_branch_from_turn(2, "cleanup_branch")
459+
branch_items: list[TResponseInputItem] = [
460+
{"role": "user", "content": "Branch-only question"},
461+
{"role": "assistant", "content": "Branch-only answer"},
462+
]
463+
await session.add_items(branch_items)
464+
465+
await session.delete_branch("cleanup_branch", force=True)
466+
467+
with session._locked_connection() as conn:
468+
rows = conn.execute(
469+
f"""
470+
SELECT message_data
471+
FROM {session.messages_table}
472+
WHERE session_id = ?
473+
ORDER BY id
474+
""",
475+
(session.session_id,),
476+
).fetchall()
477+
478+
contents = [json.loads(message_data)["content"] for (message_data,) in rows]
479+
assert contents == [
480+
"First question",
481+
"First answer",
482+
"Second question",
483+
"Second answer",
484+
]
485+
assert await session.get_items(branch_id="main") == main_items
486+
487+
session.close()
488+
489+
490+
async def test_delete_branch_keeps_messages_still_referenced_by_another_branch():
491+
"""Deleting one branch should keep messages inherited by a surviving branch."""
492+
session = AdvancedSQLiteSession(
493+
session_id="branch_delete_shared_descendant_test",
494+
create_tables=True,
495+
)
496+
497+
main_items: list[TResponseInputItem] = [
498+
{"role": "user", "content": "Main first question"},
499+
{"role": "assistant", "content": "Main first answer"},
500+
{"role": "user", "content": "Main second question"},
501+
{"role": "assistant", "content": "Main second answer"},
502+
]
503+
branch_a_shared_items: list[TResponseInputItem] = [
504+
{"role": "user", "content": "Branch A shared question"},
505+
{"role": "assistant", "content": "Branch A shared answer"},
506+
]
507+
branch_a_only_items: list[TResponseInputItem] = [
508+
{"role": "user", "content": "Branch A only question"},
509+
{"role": "assistant", "content": "Branch A only answer"},
510+
]
511+
512+
try:
513+
await session.add_items(main_items)
514+
await session.create_branch_from_turn(2, "branch_a")
515+
await session.add_items(branch_a_shared_items + branch_a_only_items)
516+
517+
await session.create_branch_from_turn(3, "branch_b")
518+
await session.delete_branch("branch_a")
519+
520+
with session._locked_connection() as conn:
521+
rows = conn.execute(
522+
f"""
523+
SELECT message_data
524+
FROM {session.messages_table}
525+
WHERE session_id = ?
526+
ORDER BY id
527+
""",
528+
(session.session_id,),
529+
).fetchall()
530+
531+
contents = [json.loads(message_data)["content"] for (message_data,) in rows]
532+
assert "Branch A shared question" in contents
533+
assert "Branch A shared answer" in contents
534+
assert "Branch A only question" not in contents
535+
assert "Branch A only answer" not in contents
536+
assert await session.get_items(branch_id="branch_b") == [
537+
*main_items[:2],
538+
*branch_a_shared_items,
539+
]
540+
finally:
541+
session.close()
542+
543+
544+
async def test_orphan_cleanup_uses_set_based_delete_for_many_messages():
545+
"""Orphan cleanup should not build one DELETE parameter per orphaned row."""
546+
547+
class RecordingCursor:
548+
def __init__(self, cursor: Any, connection: "RecordingConnection") -> None:
549+
self._cursor = cursor
550+
self._connection = connection
551+
552+
@property
553+
def rowcount(self) -> int:
554+
return cast(int, self._cursor.rowcount)
555+
556+
def execute(self, sql: str, parameters: Any = None) -> Any:
557+
normalized_sql = " ".join(sql.split()).upper()
558+
if normalized_sql.startswith("DELETE"):
559+
self._connection.delete_parameter_counts.append(len(parameters or ()))
560+
if parameters is None:
561+
return self._cursor.execute(sql)
562+
return self._cursor.execute(sql, parameters)
563+
564+
def fetchall(self) -> Any:
565+
return self._cursor.fetchall()
566+
567+
def close(self) -> None:
568+
self._cursor.close()
569+
570+
class RecordingConnection:
571+
def __init__(self, conn: Any) -> None:
572+
self._conn = conn
573+
self.delete_parameter_counts: list[int] = []
574+
575+
def cursor(self) -> RecordingCursor:
576+
return RecordingCursor(self._conn.cursor(), self)
577+
578+
session = AdvancedSQLiteSession(
579+
session_id="branch_delete_many_orphans_cleanup",
580+
create_tables=True,
581+
)
582+
orphan_items: list[TResponseInputItem] = [
583+
{"role": "user", "content": f"orphan {i}"} for i in range(1200)
584+
]
585+
586+
try:
587+
with session._locked_connection() as conn:
588+
session._insert_items(conn, orphan_items)
589+
conn.commit()
590+
591+
recording_conn = RecordingConnection(conn)
592+
deleted_count = session._cleanup_orphaned_messages_sync(cast(Any, recording_conn))
593+
conn.commit()
594+
595+
remaining_count = conn.execute(
596+
f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
597+
(session.session_id,),
598+
).fetchone()[0]
599+
600+
assert deleted_count == len(orphan_items)
601+
assert remaining_count == 0
602+
assert recording_conn.delete_parameter_counts == [2]
603+
finally:
604+
session.close()
605+
606+
445607
async def test_get_conversation_turns():
446608
"""Test get_conversation_turns functionality."""
447609
session_id = "conversation_turns_test"

0 commit comments

Comments
 (0)