diff --git a/src/memu/app/crud.py b/src/memu/app/crud.py index 50d63d4c..89b76c80 100644 --- a/src/memu/app/crud.py +++ b/src/memu/app/crud.py @@ -650,7 +650,7 @@ def _map_category_names_to_ids(self, names: list[str], ctx: Context) -> list[str async def _patch_category_summaries( self, - updates: dict[str, list[str]], + updates: dict[str, tuple[Any, Any]], ctx: Context, store: Database, llm_client: Any | None = None, diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 3031c56b..888d8b43 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -68,11 +68,17 @@ async def main(): # RAG-based retrieval service.retrieve_config.method = "rag" result_rag = await service.retrieve(queries=queries, where={"user_id": "123"}) + assert isinstance(result_rag, dict) + assert "items" in result_rag + assert "categories" in result_rag _print_results("RAG", result_rag) # LLM-based retrieval service.retrieve_config.method = "llm" result_llm = await service.retrieve(queries=queries, where={"user_id": "123"}) + assert isinstance(result_llm, dict) + assert "items" in result_llm + assert "categories" in result_llm _print_results("LLM", result_llm) print("\n[SQLITE] Test completed!")