|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import asyncio |
| 4 | +import threading |
3 | 5 | from typing import Any |
4 | 6 |
|
5 | 7 | import pytest |
6 | 8 |
|
7 | 9 | from utils.checkpoint import ( |
8 | 10 | CHECKPOINT_VERSION, |
| 11 | + AsyncCheckpointManager, |
9 | 12 | SyncCheckpointManager, |
10 | 13 | checkpoint_filters, |
11 | 14 | checkpoint_metadata, |
12 | 15 | ) |
13 | | -from utils.extraction import UserCheckpoint |
| 16 | +from utils.extraction import ConversationCheckpoint, UserCheckpoint |
14 | 17 |
|
15 | 18 |
|
16 | 19 | class FakeMemory: |
@@ -129,18 +132,232 @@ def test_save_checkpoint_add_and_update(monkeypatch: pytest.MonkeyPatch) -> None |
129 | 132 | assert mem.added[0]["user_id"] == "u1" |
130 | 133 | assert mem.added[0]["agent_id"] is None |
131 | 134 |
|
132 | | - # update existing (uses delete+add, not update) |
133 | | - ok2, same_id = mgr.save( |
| 135 | + # update existing (uses add-first-then-delete) |
| 136 | + ok2, new_id2 = mgr.save( |
134 | 137 | checkpoint_id=new_id, |
135 | 138 | user_id="u1", |
136 | 139 | app_id=None, |
137 | 140 | checkpoint=UserCheckpoint(), |
138 | 141 | ) |
139 | 142 | assert ok2 is True |
140 | | - assert same_id == new_id |
141 | | - # save uses delete+add instead of update to avoid embedding |
| 143 | + assert new_id2 is not None |
| 144 | + # save adds new first, then deletes old |
142 | 145 | assert new_id in mem.deleted # Old checkpoint deleted |
143 | 146 | assert len(mem.added) == 2 # New checkpoint added |
144 | 147 |
|
145 | 148 |
|
146 | | - |
def test_save_add_before_delete_order() -> None:
    """Check that updating a checkpoint calls add() before delete()."""
    memory = FakeMemory()
    manager = SyncCheckpointManager(memory)

    # Seed the store so the second save() has an existing checkpoint to replace.
    created, first_id = manager.save(
        checkpoint_id=None,
        user_id="u1",
        app_id=None,
        checkpoint=UserCheckpoint(),
    )
    assert created and first_id

    # Wrap the real add/delete so every call is logged in invocation order
    # while still delegating to the fake's own implementation.
    call_log: list[str] = []
    real_add, real_delete = memory.add, memory.delete

    def recording_add(text: str, **kwargs: Any) -> dict[str, Any]:
        call_log.append("add")
        return real_add(text, **kwargs)

    def recording_delete(memory_id: str) -> dict[str, Any]:
        call_log.append("delete")
        return real_delete(memory_id)

    memory.add = recording_add  # type: ignore[assignment]
    memory.delete = recording_delete  # type: ignore[assignment]

    # Second save targets the existing checkpoint, i.e. the update path.
    updated, _replacement_id = manager.save(
        checkpoint_id=first_id,
        user_id="u1",
        app_id=None,
        checkpoint=UserCheckpoint(),
    )
    assert updated is True
    # Exactly one add followed by exactly one delete.
    assert call_log == ["add", "delete"]
| 189 | + |
| 190 | + |
def test_save_keeps_old_checkpoint_on_add_failure() -> None:
    """Check that a failed add() leaves the previous checkpoint undeleted."""
    memory = FakeMemory()
    manager = SyncCheckpointManager(memory)

    # Seed the store with a checkpoint that the failing update will target.
    created, first_id = manager.save(
        checkpoint_id=None,
        user_id="u1",
        app_id=None,
        checkpoint=UserCheckpoint(),
    )
    assert created and first_id

    def failing_add(text: str, **kwargs: Any) -> dict[str, Any]:
        # An empty result list carries no ID; this is how the test models
        # a failed add().
        return {"results": []}

    memory.add = failing_add  # type: ignore[assignment]

    # The return value is unpacked but deliberately not asserted on; only the
    # delete side effect is checked below.
    _updated, _replacement_id = manager.save(
        checkpoint_id=first_id,
        user_id="u1",
        app_id=None,
        checkpoint=UserCheckpoint(),
    )
    # No new ID came back, so the old checkpoint must still be in place.
    assert first_id not in memory.deleted
| 216 | + |
| 217 | + |
class AsyncFakeMemory:
    """Async-compatible fake Mem0 memory for testing.

    Items are kept in ``_store`` as ``{"id", "memory", "metadata"}`` dicts.
    Every ``add`` and ``delete`` call is also recorded in ``added`` and
    ``deleted`` so tests can inspect what happened.
    """

    def __init__(self) -> None:
        self._store: list[dict[str, Any]] = []
        self.added: list[dict[str, Any]] = []
        self.deleted: list[str] = []
        # Monotonic sequence for generated IDs. Deriving the ID from
        # len(self._store) would reuse an ID after a delete shrinks the store
        # (add, add, delete first, add -> duplicate "cp_2"), and a later
        # delete of that ID would then drop both records.
        self._next_seq = 1

    def _match(self, md: dict[str, Any], filt: Any) -> bool:
        """Return True when metadata ``md`` satisfies every entry of ``filt``.

        A missing, empty, or non-dict filter matches everything. Only leaf
        filters are understood: ``{"key": value}`` or ``{"key": {"eq": value}}``.
        """
        if not filt or not isinstance(filt, dict):
            return True
        for key, expected in filt.items():
            # Unwrap the {"eq": value} form so both shapes compare the same way.
            if isinstance(expected, dict) and "eq" in expected:
                expected = expected["eq"]
            if md.get(key) != expected:
                return False
        return True

    async def get_all(self, **kwargs: Any) -> dict[str, Any]:
        """Return ``{"results": [...]}`` with stored items matching ``filters``."""
        filt = kwargs.get("filters")
        matches = [
            item
            for item in self._store
            if self._match(item.get("metadata") or {}, filt)
        ]
        return {"results": matches}

    async def add(self, text: str, **kwargs: Any) -> dict[str, Any]:
        """Store ``text`` with optional ``metadata`` and return its new ID.

        The ID is ``cp_<n>`` where ``n`` increases on every call and is never
        reused, even after deletions.
        """
        md = kwargs.get("metadata") or {}
        new_id = f"cp_{self._next_seq}"
        self._next_seq += 1
        self._store.append({"id": new_id, "memory": text, "metadata": md})
        self.added.append({"id": new_id, "text": text})
        return {"results": [{"id": new_id, "event": "ADD"}]}

    async def delete(self, memory_id: str) -> dict[str, Any]:
        """Record the deletion and drop any stored item with ``memory_id``."""
        self.deleted.append(memory_id)
        self._store = [x for x in self._store if x.get("id") != memory_id]
        return {"message": "deleted"}
| 259 | + |
| 260 | + |
def _run_async(coro: Any) -> Any:
    """Drive ``coro`` to completion on a fresh event loop in a separate thread.

    The calling thread blocks until the worker finishes. Any exception raised
    while running the coroutine (including ``BaseException`` subclasses) is
    re-raised in the caller; otherwise the coroutine's result is returned.
    """
    outcome: dict[str, Any] = {}

    def _worker() -> None:
        # Defensive reset of this thread's loop state before creating a loop.
        # Intended to avoid "Cannot run the event loop while another loop is
        # running" when the test session itself runs inside an event loop
        # (e.g. pytest-asyncio AUTO mode).
        # NOTE(review): relies on a private asyncio API; confirm it is still
        # required on the supported Python versions.
        asyncio.events._set_running_loop(None)  # type: ignore[attr-defined]
        asyncio.set_event_loop(None)
        private_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(private_loop)
        try:
            outcome["value"] = private_loop.run_until_complete(coro)
        except BaseException as err:  # noqa: BLE001
            # Captured here and re-raised by the caller after join().
            outcome["error"] = err
        finally:
            private_loop.close()
            asyncio.set_event_loop(None)

    worker = threading.Thread(target=_worker)
    worker.start()
    worker.join()
    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("value")
| 290 | + |
| 291 | + |
def test_async_load_restores_resume_fields() -> None:
    """Round-trip a checkpoint through AsyncCheckpointManager and check resume_* fields."""

    async def _scenario() -> None:
        memory = AsyncFakeMemory()
        manager = AsyncCheckpointManager(memory)

        # One conversation entry plus all three resume_* fields populated.
        conversation = ConversationCheckpoint(
            last_processed_message_id="msg-100",
            processed_range_start="2026-01-01T00:00:00Z",
            processed_range_end="2026-01-02T00:00:00Z",
        )
        original = UserCheckpoint(
            conversations={"conv1": conversation},
            resume_conversation_cursor="conv1",
            resume_run_at="2026-04-01T12:00:00Z",
            resume_start_time="2026-03-01T00:00:00Z",
        )

        saved, saved_id = await manager.save(
            checkpoint_id=None, user_id="u1", app_id=None, checkpoint=original
        )
        assert saved and saved_id

        # Loading for the same user must return the same record and contents.
        restored_id, restored = await manager.load(user_id="u1", app_id=None)
        assert restored_id == saved_id
        assert restored is not None
        assert restored.resume_conversation_cursor == "conv1"
        assert restored.resume_run_at == "2026-04-01T12:00:00Z"
        assert restored.resume_start_time == "2026-03-01T00:00:00Z"
        assert "conv1" in restored.conversations
        assert restored.conversations["conv1"].last_processed_message_id == "msg-100"

    _run_async(_scenario())
| 327 | + |
| 328 | + |
def test_async_save_add_before_delete() -> None:
    """Check that the async update path calls add() before delete()."""

    async def _scenario() -> None:
        memory = AsyncFakeMemory()
        manager = AsyncCheckpointManager(memory)

        # Seed the store so the second save() has a checkpoint to replace.
        created, first_id = await manager.save(
            checkpoint_id=None, user_id="u1", app_id=None, checkpoint=UserCheckpoint()
        )
        assert created and first_id

        # Log each call in order while delegating to the fake's own coroutines.
        call_log: list[str] = []
        real_add, real_delete = memory.add, memory.delete

        async def recording_add(text: str, **kwargs: Any) -> dict[str, Any]:
            call_log.append("add")
            return await real_add(text, **kwargs)

        async def recording_delete(memory_id: str) -> dict[str, Any]:
            call_log.append("delete")
            return await real_delete(memory_id)

        memory.add = recording_add  # type: ignore[assignment]
        memory.delete = recording_delete  # type: ignore[assignment]

        updated, _replacement_id = await manager.save(
            checkpoint_id=first_id, user_id="u1", app_id=None, checkpoint=UserCheckpoint()
        )
        assert updated is True
        # Exactly one add followed by exactly one delete.
        assert call_log == ["add", "delete"]

    _run_async(_scenario())
| 363 | + |
0 commit comments