Skip to content

Commit f55a5eb

Browse files
authored
Merge pull request #44 from beersoccer/fix/checkpoint-lock-reliability
Fix/checkpoint lock reliability
2 parents 07d6709 + a66936c commit f55a5eb

7 files changed

Lines changed: 1859 additions & 82 deletions

File tree

tests/unit/tools/test_forget_memories.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
import json
4+
from datetime import UTC, datetime, timedelta
45
from unittest.mock import MagicMock, patch
56

6-
from tools.forget_memories import ForgetMemoriesTool
7+
from tools.forget_memories import ForgetMemoriesTool, _clean_expired_locks
78

89

910
def _extract_json_payload(message: object) -> dict:
@@ -145,3 +146,124 @@ def test_forget_memories_deletes_old_checkpoints_in_real_run() -> None:
145146
payload = _extract_json_payload(messages[0])
146147
assert payload.get("status") == "SUCCESS"
147148
assert payload.get("results", {}).get("checkpoints_cleaned") == 2
149+
150+
151+
def _make_lock_memory(
152+
mem_id: str,
153+
holder_id: str,
154+
acquired_at: str,
155+
ttl_seconds: int,
156+
) -> dict:
157+
"""Build a fake lock memory entry for testing."""
158+
lock_data = {
159+
"lock_id": "lock:extraction:u1:*",
160+
"holder_id": holder_id,
161+
"acquired_at": acquired_at,
162+
"ttl_seconds": ttl_seconds,
163+
}
164+
return {
165+
"id": mem_id,
166+
"memory": json.dumps(lock_data, ensure_ascii=False),
167+
"metadata": {
168+
"__internal": True,
169+
"internal_type": "distributed_lock",
170+
},
171+
}
172+
173+
174+
def test_clean_expired_locks_deletes_expired_only() -> None:
175+
"""_clean_expired_locks should delete expired locks and keep active ones."""
176+
now = datetime.now(UTC)
177+
expired_time = (now - timedelta(seconds=7200)).isoformat()
178+
active_time = now.isoformat()
179+
180+
mem = MagicMock()
181+
mem.get_all.return_value = {
182+
"results": [
183+
_make_lock_memory("lock-1", "run-old", expired_time, 3600),
184+
_make_lock_memory("lock-2", "run-active", active_time, 3600),
185+
]
186+
}
187+
188+
count = _clean_expired_locks(
189+
mem, user_id="u1", app_id=None, dry_run=False, request_id="req-1"
190+
)
191+
192+
assert count == 1
193+
mem.delete.assert_called_once_with("lock-1")
194+
195+
196+
def test_clean_expired_locks_dry_run_does_not_delete() -> None:
197+
"""Dry run should count but not delete."""
198+
expired_time = (datetime.now(UTC) - timedelta(seconds=7200)).isoformat()
199+
200+
mem = MagicMock()
201+
mem.get_all.return_value = {
202+
"results": [
203+
_make_lock_memory("lock-1", "run-old", expired_time, 3600),
204+
]
205+
}
206+
207+
count = _clean_expired_locks(
208+
mem, user_id="u1", app_id=None, dry_run=True, request_id="req-1"
209+
)
210+
211+
assert count == 1
212+
mem.delete.assert_not_called()
213+
214+
215+
def test_clean_expired_locks_corrupted_record() -> None:
216+
"""Corrupted lock records (unparseable JSON) should be cleaned up."""
217+
mem = MagicMock()
218+
mem.get_all.return_value = {
219+
"results": [
220+
{
221+
"id": "lock-corrupt",
222+
"memory": "not-valid-json",
223+
"metadata": {
224+
"__internal": True,
225+
"internal_type": "distributed_lock",
226+
},
227+
}
228+
]
229+
}
230+
231+
count = _clean_expired_locks(
232+
mem, user_id="u1", app_id=None, dry_run=False, request_id="req-1"
233+
)
234+
235+
assert count == 1
236+
mem.delete.assert_called_once_with("lock-corrupt")
237+
238+
239+
def test_forget_memories_includes_locks_cleaned_in_result() -> None:
240+
"""Integration: ForgetMemoriesTool result should include locks_cleaned."""
241+
tool = _build_tool()
242+
mem = MagicMock()
243+
244+
# Call 1: regular memories, Call 2: checkpoints, Call 3: locks
245+
expired_time = (datetime.now(UTC) - timedelta(seconds=7200)).isoformat()
246+
mem.get_all.side_effect = [
247+
{"results": []}, # no regular memories
248+
{"results": []}, # no checkpoints
249+
{
250+
"results": [
251+
_make_lock_memory("lock-1", "old-run", expired_time, 3600),
252+
]
253+
}, # one expired lock
254+
]
255+
client = MagicMock(memory=mem)
256+
mgr = MagicMock()
257+
mgr.load.return_value = ("log-1", {})
258+
259+
with (
260+
patch("tools.forget_memories.init_request_context", return_value=("req-1", 0.0)),
261+
patch("tools.forget_memories.validate_user_id", return_value="u1"),
262+
patch("tools.forget_memories.get_sync_client", return_value=client),
263+
patch("tools.forget_memories.SyncAccessLogManager", return_value=mgr),
264+
):
265+
messages = list(tool._invoke({"user_id": "u1", "dry_run": False}))
266+
267+
payload = _extract_json_payload(messages[0])
268+
assert payload.get("status") == "SUCCESS"
269+
assert payload.get("results", {}).get("locks_cleaned") == 1

tests/unit/utils/test_checkpoint.py

Lines changed: 223 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import threading
35
from typing import Any
46

57
import pytest
68

79
from utils.checkpoint import (
810
CHECKPOINT_VERSION,
11+
AsyncCheckpointManager,
912
SyncCheckpointManager,
1013
checkpoint_filters,
1114
checkpoint_metadata,
1215
)
13-
from utils.extraction import UserCheckpoint
16+
from utils.extraction import ConversationCheckpoint, UserCheckpoint
1417

1518

1619
class FakeMemory:
@@ -129,18 +132,232 @@ def test_save_checkpoint_add_and_update(monkeypatch: pytest.MonkeyPatch) -> None
129132
assert mem.added[0]["user_id"] == "u1"
130133
assert mem.added[0]["agent_id"] is None
131134

132-
# update existing (uses delete+add, not update)
133-
ok2, same_id = mgr.save(
135+
# update existing (uses add-first-then-delete)
136+
ok2, new_id2 = mgr.save(
134137
checkpoint_id=new_id,
135138
user_id="u1",
136139
app_id=None,
137140
checkpoint=UserCheckpoint(),
138141
)
139142
assert ok2 is True
140-
assert same_id == new_id
141-
# save uses delete+add instead of update to avoid embedding
143+
assert new_id2 is not None
144+
# save adds new first, then deletes old
142145
assert new_id in mem.deleted # Old checkpoint deleted
143146
assert len(mem.added) == 2 # New checkpoint added
144147

145148

146-
149+
def test_save_add_before_delete_order() -> None:
150+
"""Verify that save() adds the new checkpoint BEFORE deleting the old one."""
151+
mem = FakeMemory()
152+
mgr = SyncCheckpointManager(mem)
153+
154+
# Create initial checkpoint
155+
ok, initial_id = mgr.save(
156+
checkpoint_id=None,
157+
user_id="u1",
158+
app_id=None,
159+
checkpoint=UserCheckpoint(),
160+
)
161+
assert ok and initial_id
162+
163+
# Track operation order
164+
operations: list[str] = []
165+
original_add = mem.add
166+
original_delete = mem.delete
167+
168+
def tracking_add(text: str, **kwargs: Any) -> dict[str, Any]:
169+
operations.append("add")
170+
return original_add(text, **kwargs)
171+
172+
def tracking_delete(memory_id: str) -> dict[str, Any]:
173+
operations.append("delete")
174+
return original_delete(memory_id)
175+
176+
mem.add = tracking_add # type: ignore[assignment]
177+
mem.delete = tracking_delete # type: ignore[assignment]
178+
179+
# Update checkpoint
180+
ok2, new_id = mgr.save(
181+
checkpoint_id=initial_id,
182+
user_id="u1",
183+
app_id=None,
184+
checkpoint=UserCheckpoint(),
185+
)
186+
assert ok2 is True
187+
# add must come before delete
188+
assert operations == ["add", "delete"]
189+
190+
191+
def test_save_keeps_old_checkpoint_on_add_failure() -> None:
192+
"""If add() fails, old checkpoint should NOT be deleted."""
193+
mem = FakeMemory()
194+
mgr = SyncCheckpointManager(mem)
195+
196+
# Create initial checkpoint
197+
ok, initial_id = mgr.save(
198+
checkpoint_id=None,
199+
user_id="u1",
200+
app_id=None,
201+
checkpoint=UserCheckpoint(),
202+
)
203+
assert ok and initial_id
204+
205+
# Make add return no ID (simulating failure)
206+
mem.add = lambda text, **kwargs: {"results": []} # type: ignore[assignment]
207+
208+
ok2, new_id = mgr.save(
209+
checkpoint_id=initial_id,
210+
user_id="u1",
211+
app_id=None,
212+
checkpoint=UserCheckpoint(),
213+
)
214+
# add returned no ID, so old checkpoint must NOT be deleted
215+
assert initial_id not in mem.deleted
216+
217+
218+
class AsyncFakeMemory:
219+
"""Async-compatible fake Mem0 memory for testing."""
220+
221+
def __init__(self) -> None:
222+
self._store: list[dict[str, Any]] = []
223+
self.added: list[dict[str, Any]] = []
224+
self.deleted: list[str] = []
225+
226+
def _match(self, md: dict[str, Any], filt: Any) -> bool:
227+
if not filt or not isinstance(filt, dict):
228+
return True
229+
# Leaf filter: {"key": value} or {"key": {"eq": value}}
230+
for k, v in filt.items():
231+
if isinstance(v, dict) and "eq" in v:
232+
if md.get(k) != v.get("eq"):
233+
return False
234+
else:
235+
if md.get(k) != v:
236+
return False
237+
return True
238+
239+
async def get_all(self, **kwargs: Any) -> dict[str, Any]:
240+
filt = kwargs.get("filters")
241+
out: list[dict[str, Any]] = []
242+
for item in self._store:
243+
md = item.get("metadata") or {}
244+
if self._match(md, filt):
245+
out.append(item)
246+
return {"results": out}
247+
248+
async def add(self, text: str, **kwargs: Any) -> dict[str, Any]:
249+
md = kwargs.get("metadata") or {}
250+
new_id = f"cp_{len(self._store)+1}"
251+
self._store.append({"id": new_id, "memory": text, "metadata": md})
252+
self.added.append({"id": new_id, "text": text})
253+
return {"results": [{"id": new_id, "event": "ADD"}]}
254+
255+
async def delete(self, memory_id: str) -> dict[str, Any]:
256+
self.deleted.append(memory_id)
257+
self._store = [x for x in self._store if x.get("id") != memory_id]
258+
return {"message": "deleted"}
259+
260+
261+
def _run_async(coro: Any) -> Any:
262+
"""Run a coroutine in a dedicated thread to avoid event loop conflicts with pytest-asyncio."""
263+
result: list[Any] = []
264+
exc: list[BaseException] = []
265+
266+
def _thread_target() -> None:
267+
# Clear any inherited running-loop state from the parent thread.
268+
# In pytest-asyncio AUTO mode the main thread runs inside an event
269+
# loop, and CPython's C-level _get_running_loop() can inherit that
270+
# state in child threads, causing run_until_complete() to raise
271+
# "Cannot run the event loop while another loop is running".
272+
asyncio.events._set_running_loop(None) # type: ignore[attr-defined]
273+
asyncio.set_event_loop(None)
274+
loop = asyncio.new_event_loop()
275+
asyncio.set_event_loop(loop)
276+
try:
277+
result.append(loop.run_until_complete(coro))
278+
except BaseException as e: # noqa: BLE001
279+
exc.append(e)
280+
finally:
281+
loop.close()
282+
asyncio.set_event_loop(None)
283+
284+
t = threading.Thread(target=_thread_target)
285+
t.start()
286+
t.join()
287+
if exc:
288+
raise exc[0]
289+
return result[0] if result else None
290+
291+
292+
def test_async_load_restores_resume_fields() -> None:
293+
"""AsyncCheckpointManager.load() must restore resume_* fields."""
294+
295+
async def _run() -> None:
296+
mem = AsyncFakeMemory()
297+
mgr = AsyncCheckpointManager(mem)
298+
299+
cp = UserCheckpoint(
300+
conversations={
301+
"conv1": ConversationCheckpoint(
302+
last_processed_message_id="msg-100",
303+
processed_range_start="2026-01-01T00:00:00Z",
304+
processed_range_end="2026-01-02T00:00:00Z",
305+
),
306+
},
307+
resume_conversation_cursor="conv1",
308+
resume_run_at="2026-04-01T12:00:00Z",
309+
resume_start_time="2026-03-01T00:00:00Z",
310+
)
311+
312+
ok, cp_id = await mgr.save(
313+
checkpoint_id=None, user_id="u1", app_id=None, checkpoint=cp
314+
)
315+
assert ok and cp_id
316+
317+
loaded_id, loaded_cp = await mgr.load(user_id="u1", app_id=None)
318+
assert loaded_id == cp_id
319+
assert loaded_cp is not None
320+
assert loaded_cp.resume_conversation_cursor == "conv1"
321+
assert loaded_cp.resume_run_at == "2026-04-01T12:00:00Z"
322+
assert loaded_cp.resume_start_time == "2026-03-01T00:00:00Z"
323+
assert "conv1" in loaded_cp.conversations
324+
assert loaded_cp.conversations["conv1"].last_processed_message_id == "msg-100"
325+
326+
_run_async(_run())
327+
328+
329+
def test_async_save_add_before_delete() -> None:
330+
"""Async save() must add new checkpoint before deleting old."""
331+
332+
async def _run() -> None:
333+
mem = AsyncFakeMemory()
334+
mgr = AsyncCheckpointManager(mem)
335+
336+
ok, initial_id = await mgr.save(
337+
checkpoint_id=None, user_id="u1", app_id=None, checkpoint=UserCheckpoint()
338+
)
339+
assert ok and initial_id
340+
341+
operations: list[str] = []
342+
original_add = mem.add
343+
original_delete = mem.delete
344+
345+
async def tracking_add(text: str, **kwargs: Any) -> dict[str, Any]:
346+
operations.append("add")
347+
return await original_add(text, **kwargs)
348+
349+
async def tracking_delete(memory_id: str) -> dict[str, Any]:
350+
operations.append("delete")
351+
return await original_delete(memory_id)
352+
353+
mem.add = tracking_add # type: ignore[assignment]
354+
mem.delete = tracking_delete # type: ignore[assignment]
355+
356+
ok2, new_id = await mgr.save(
357+
checkpoint_id=initial_id, user_id="u1", app_id=None, checkpoint=UserCheckpoint()
358+
)
359+
assert ok2 is True
360+
assert operations == ["add", "delete"]
361+
362+
_run_async(_run())
363+

0 commit comments

Comments
 (0)