|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import asyncio |
| 4 | +from unittest.mock import MagicMock |
| 5 | + |
| 6 | +from provider.mem0ai import Mem0Provider |
| 7 | + |
| 8 | + |
| 9 | +def test_validate_credentials_async_mode_uses_async_client_search(monkeypatch) -> None: |
| 10 | + import provider.mem0ai as provider_mod |
| 11 | + |
| 12 | + captured: dict[str, object] = {} |
| 13 | + fake_loop = object() |
| 14 | + fake_future = MagicMock() |
| 15 | + fake_future.result.return_value = {"results": []} |
| 16 | + |
| 17 | + class FakeClient: |
| 18 | + def ensure_bg_loop(self) -> object: |
| 19 | + captured["ensure_bg_loop_called"] = True |
| 20 | + return fake_loop |
| 21 | + |
| 22 | + async def search(self, payload: dict[str, object], timeout_s: int) -> dict[str, object]: |
| 23 | + captured["search_payload"] = payload |
| 24 | + captured["search_timeout"] = timeout_s |
| 25 | + return {"results": []} |
| 26 | + |
| 27 | + def _fake_run_coroutine_threadsafe(coro, loop): # noqa: ANN001 |
| 28 | + assert asyncio.iscoroutine(coro) |
| 29 | + captured["loop"] = loop |
| 30 | + coro.close() |
| 31 | + return fake_future |
| 32 | + |
| 33 | + monkeypatch.setattr(provider_mod, "get_async_client", lambda _credentials: FakeClient()) |
| 34 | + monkeypatch.setattr( |
| 35 | + provider_mod.asyncio, |
| 36 | + "run_coroutine_threadsafe", |
| 37 | + _fake_run_coroutine_threadsafe, |
| 38 | + ) |
| 39 | + |
| 40 | + provider = object.__new__(Mem0Provider) |
| 41 | + provider._validate_credentials({"async_mode": True, "log_level": "INFO"}) |
| 42 | + |
| 43 | + assert captured["ensure_bg_loop_called"] is True |
| 44 | + assert captured["loop"] is fake_loop |
| 45 | + assert fake_future.result.call_count == 1 |
0 commit comments