@@ -80,6 +80,52 @@ def create_mock_run_result(usage: Usage | None = None, agent: Agent | None = Non
8080 )
8181
8282
class FailingOnceStructureMetadataSession(AdvancedSQLiteSession):
    """Test double whose structure metadata write raises while a flag is armed.

    The ``fail_structure_metadata_once`` flag starts out ``True``. The first
    call to ``_insert_structure_metadata`` made while it is set clears the flag
    and raises ``RuntimeError``; calls made while it is cleared delegate to the
    parent implementation. Tests may flip the flag by hand to re-arm or disarm
    the failure.
    """

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        # Armed by default so the very first metadata write fails.
        self.fail_structure_metadata_once = True

    def _insert_structure_metadata(self, conn: Any, items: list[TResponseInputItem]) -> None:
        """Raise once while armed; otherwise defer to the parent implementation."""
        if not self.fail_structure_metadata_once:
            super()._insert_structure_metadata(conn, items)
            return

        # Disarm before raising so that a retry reaches the real implementation.
        self.fail_structure_metadata_once = False
        raise RuntimeError("structure metadata failed")
99+
100+
class PartiallyFailingStructureMetadataSession(AdvancedSQLiteSession):
    """Test double that writes a single structure row and then raises.

    Used to check what happens to a metadata row that was already written on
    the connection when the rest of the metadata step fails.
    """

    def _insert_structure_metadata(self, conn: Any, items: list[TResponseInputItem]) -> None:
        """Insert one ``message_structure`` row for this session, then fail.

        Raises:
            RuntimeError: always -- either because no message row exists for
                this session, or deliberately after the single row is written.
        """
        # Lowest message id stored for this session on the given connection.
        first_message = conn.execute(
            f"SELECT id FROM {self.messages_table} WHERE session_id = ? ORDER BY id ASC LIMIT 1",
            (self.session_id,),
        ).fetchone()
        if first_message is None:
            raise RuntimeError("no inserted message id found")

        # Column order matches the INSERT statement below.
        structure_row = (
            self.session_id,
            first_message[0],
            self._current_branch_id,
            "user",
            1,
            1,
            1,
            None,
        )
        conn.execute(
            """
            INSERT INTO message_structure
            (session_id, message_id, branch_id, message_type, sequence_number,
             user_turn_number, branch_turn_number, tool_name)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        """,
            structure_row,
        )
        # Fail only after the row above has been written.
        raise RuntimeError("structure metadata failed after partial write")
127+
128+
83129async def test_advanced_session_basic_functionality (agent : Agent ):
84130 """Test basic AdvancedSQLiteSession functionality."""
85131 session_id = "advanced_test"
@@ -147,6 +193,133 @@ async def test_advanced_session_respects_custom_table_names():
147193 session .close ()
148194
149195
async def test_add_items_rolls_back_messages_when_structure_metadata_fails():
    """A failed structure metadata write must not leave message rows behind.

    The session double raises on its first metadata write, so the single
    ``add_items`` call is expected to fail and persist nothing.
    """
    session = FailingOnceStructureMetadataSession(
        session_id="advanced_add_items_rollback",
        create_tables=True,
    )
    rejected_batch: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]

    try:
        with pytest.raises(RuntimeError, match="structure metadata failed"):
            await session.add_items(rejected_batch)

        # The public read path must not surface anything from the failed batch.
        visible_items = await session.get_items()
        assert visible_items == []

        # Count rows in both tables directly, bypassing the session's read API.
        owner = (session.session_id,)
        with session._locked_connection() as conn:
            messages_cursor = conn.execute(
                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
                owner,
            )
            message_rows = messages_cursor.fetchone()[0]
            structure_cursor = conn.execute(
                "SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
                owner,
            )
            structure_rows = structure_cursor.fetchone()[0]

        assert message_rows == 0
        assert structure_rows == 0
    finally:
        session.close()
224+
225+
async def test_add_items_can_retry_after_structure_metadata_failure():
    """Re-submitting a batch after a metadata failure stores it exactly once.

    The first ``add_items`` call fails inside the session double; the second
    call with the same items is expected to succeed and leave one message row
    and one structure row.
    """
    session = FailingOnceStructureMetadataSession(
        session_id="advanced_add_items_retry",
        create_tables=True,
    )
    batch: list[TResponseInputItem] = [{"role": "user", "content": "saved once"}]

    try:
        # First attempt: the double raises and disarms itself.
        with pytest.raises(RuntimeError, match="structure metadata failed"):
            await session.add_items(batch)

        # Second attempt: same payload, now expected to go through.
        await session.add_items(batch)

        stored_items = await session.get_items()
        assert stored_items == batch

        # Count rows in both tables directly, bypassing the session's read API.
        owner = (session.session_id,)
        with session._locked_connection() as conn:
            messages_cursor = conn.execute(
                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
                owner,
            )
            message_rows = messages_cursor.fetchone()[0]
            structure_cursor = conn.execute(
                "SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
                owner,
            )
            structure_rows = structure_cursor.fetchone()[0]

        assert message_rows == 1
        assert structure_rows == 1
    finally:
        session.close()
256+
257+
async def test_add_items_failure_preserves_existing_history():
    """A failing batch must leave earlier, successfully stored items intact.

    One batch is stored with the failure disarmed, then a second batch is
    submitted with the failure re-armed. Only the first batch should remain.
    """
    session = FailingOnceStructureMetadataSession(
        session_id="advanced_add_items_existing_history",
        create_tables=True,
    )
    committed_batch: list[TResponseInputItem] = [{"role": "user", "content": "already saved"}]
    rejected_batch: list[TResponseInputItem] = [{"role": "assistant", "content": "not saved"}]

    try:
        # Disarm the double so the first batch is stored normally.
        session.fail_structure_metadata_once = False
        await session.add_items(committed_batch)

        # Re-arm it so the second batch fails during the metadata write.
        session.fail_structure_metadata_once = True
        with pytest.raises(RuntimeError, match="structure metadata failed"):
            await session.add_items(rejected_batch)

        stored_items = await session.get_items()
        assert stored_items == committed_batch

        # Count rows in both tables directly, bypassing the session's read API.
        owner = (session.session_id,)
        with session._locked_connection() as conn:
            messages_cursor = conn.execute(
                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
                owner,
            )
            message_rows = messages_cursor.fetchone()[0]
            structure_cursor = conn.execute(
                "SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
                owner,
            )
            structure_rows = structure_cursor.fetchone()[0]

        assert message_rows == 1
        assert structure_rows == 1
    finally:
        session.close()
291+
292+
async def test_add_items_rolls_back_partial_structure_metadata_write():
    """A metadata row written before the failure must vanish with the batch.

    The session double inserts one ``message_structure`` row and then raises,
    so neither that row nor the batch's message row should remain afterwards.
    """
    session = PartiallyFailingStructureMetadataSession(
        session_id="advanced_add_items_partial_metadata",
        create_tables=True,
    )
    rejected_batch: list[TResponseInputItem] = [{"role": "user", "content": "not saved"}]

    try:
        with pytest.raises(RuntimeError, match="structure metadata failed after partial write"):
            await session.add_items(rejected_batch)

        # The public read path must not surface anything from the failed batch.
        visible_items = await session.get_items()
        assert visible_items == []

        # Count rows in both tables directly, bypassing the session's read API.
        owner = (session.session_id,)
        with session._locked_connection() as conn:
            messages_cursor = conn.execute(
                f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?",
                owner,
            )
            message_rows = messages_cursor.fetchone()[0]
            structure_cursor = conn.execute(
                "SELECT COUNT(*) FROM message_structure WHERE session_id = ?",
                owner,
            )
            structure_rows = structure_cursor.fetchone()[0]

        assert message_rows == 0
        assert structure_rows == 0
    finally:
        session.close()
321+
322+
150323async def test_message_structure_tracking (agent : Agent ):
151324 """Test that message structure is properly tracked."""
152325 session_id = "structure_test"
0 commit comments