diff --git a/src/memu/app/memorize.py b/src/memu/app/memorize.py index 0f2a06fc..f0c6cc3d 100644 --- a/src/memu/app/memorize.py +++ b/src/memu/app/memorize.py @@ -37,6 +37,8 @@ logger = logging.getLogger(__name__) +UNCATEGORIZED_CATEGORY_NAME = "uncategorized" + if TYPE_CHECKING: from memu.app.service import Context from memu.app.settings import MemorizeConfig @@ -615,6 +617,10 @@ async def _persist_memory_items( # existing item continue mapped_cat_ids = self._map_category_names_to_ids(cat_names, ctx) + if not mapped_cat_ids and self.memorize_config.enable_uncategorized_fallback: + fallback_id = ctx.category_name_to_id.get(UNCATEGORIZED_CATEGORY_NAME) + if fallback_id is not None: + mapped_cat_ids = [fallback_id] for cid in mapped_cat_ids: rels.append(store.category_item_repo.link_item_category(item.id, cid, user_data=dict(user or {}))) # Store (item_id, summary) tuple for reference support @@ -650,19 +656,36 @@ async def _initialize_categories( ) -> None: if ctx.categories_ready: return - if not self.category_configs: + configs = list(self.category_configs) + if self.memorize_config.enable_uncategorized_fallback and not any( + cfg.name.lower() == UNCATEGORIZED_CATEGORY_NAME for cfg in configs + ): + configs.append( + CategoryConfig( + name=UNCATEGORIZED_CATEGORY_NAME, + description="Memory items that do not match any other configured category.", + ) + ) + if not configs: ctx.categories_ready = True return - cat_texts = [self._category_embedding_text(cfg) for cfg in self.category_configs] + cat_texts = [self._category_embedding_text(cfg) for cfg in configs] cat_vecs = await self._get_llm_client("embedding").embed(cat_texts) ctx.category_ids = [] ctx.category_name_to_id = {} - for cfg, vec in zip(self.category_configs, cat_vecs, strict=True): + for cfg, vec in zip(configs, cat_vecs, strict=True): name = cfg.name.strip() or "Untitled" description = cfg.description.strip() cat = store.memory_category_repo.get_or_create_category( name=name, description=description, embedding=vec, user_data=dict(user or {}) ) + # Seed a static summary for the fallback: _rank_categories_by_summary + # in retrieve.py filters on `cat.summary`, and we deliberately skip + # dynamic summary updates for this category (see _update_category_summaries). + # Must go through update_category so the value persists on SQL backends + # whose get_or_create_category returns a detached/copied instance. + if name.lower() == UNCATEGORIZED_CATEGORY_NAME and not cat.summary and description: + cat = store.memory_category_repo.update_category(category_id=cat.id, summary=description) ctx.category_ids.append(cat.id) ctx.category_name_to_id[name.lower()] = cat.id ctx.categories_ready = True @@ -1120,6 +1143,11 @@ async def _update_category_summaries( cat = store.memory_category_repo.categories.get(cid) if not cat or not memories: continue + if cat.name.lower() == UNCATEGORIZED_CATEGORY_NAME: + # Summaries over heterogeneous uncategorized items are incoherent and + # waste tokens; the category's static description embedding is enough + # for route_category to consider it when nothing else matches. + continue prompt = self._build_category_summary_prompt(category=cat, new_memories=memories) tasks.append(client.chat(prompt)) target_ids.append(cid) @@ -1259,7 +1287,20 @@ def _parse_memory_type_response(self, raw: str) -> list[dict[str, Any]]: def _find_xml_boundaries(self, raw: str) -> tuple[int, int, str] | None: """Find the start index, end index, and closing tag for XML root element.""" - root_tags = ["item", "profile", "behaviors", "events", "knowledge", "skills"] + root_tags = [ + "item", + "profile", + "profiles", + "event", + "events", + "knowledge", + "behavior", + "behaviors", + "skill", + "skills", + "tool", + "tools", + ] for tag in root_tags: opening = f"<{tag}>" closing = f"" @@ -1283,7 +1324,9 @@ def _parse_memory_element(self, memory_elem: Element) -> dict[str, Any] | None: categories = [cat_elem.text.strip() for cat_elem in categories_elem.findall("category") if cat_elem.text] memory_dict["categories"] = categories - if memory_dict.get("content") and memory_dict.get("categories"): + # `categories` may be empty per the memory type prompts. + if memory_dict.get("content"): + memory_dict.setdefault("categories", []) return memory_dict return None @@ -1292,7 +1335,7 @@ def _parse_memory_type_response_xml(self, raw: str) -> list[dict[str, Any]]: Parse XML memory extraction output into a list of memory items. Expected XML format (root tag varies by memory type): - + ... diff --git a/src/memu/app/settings.py b/src/memu/app/settings.py index adcb4f16..e27f4596 100644 --- a/src/memu/app/settings.py +++ b/src/memu/app/settings.py @@ -240,6 +240,14 @@ class MemorizeConfig(BaseModel): default=False, description="Enable reinforcement tracking for memory items.", ) + enable_uncategorized_fallback: bool = Field( + default=True, + description=( + "Auto-create an 'uncategorized' category and link memory items whose " + "extracted categories match none of the configured ones, so LLM-mode " + "retrieval (which joins items via category relations) can still reach them." + ), + ) class PatchConfig(BaseModel): diff --git a/tests/test_uncategorized_fallback.py b/tests/test_uncategorized_fallback.py new file mode 100644 index 00000000..4abff24c --- /dev/null +++ b/tests/test_uncategorized_fallback.py @@ -0,0 +1,200 @@ +""" +Tests for the uncategorized-fallback feature in MemorizeMixin. + +When the LLM extracts a memory item that maps to no configured category, +it is linked to an auto-created "uncategorized" category so it remains +reachable from LLM-mode retrieval (which joins items via category relations). +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from memu.app.memorize import UNCATEGORIZED_CATEGORY_NAME, MemorizeMixin +from memu.app.service import Context +from memu.app.settings import CategoryConfig, DatabaseConfig, DefaultUserModel, MemorizeConfig +from memu.database.factory import build_database +from memu.database.interfaces import Database + + +class FakeLLMClient: + """Deterministic stand-in for the embedding and chat client.""" + + chat_model = "fake-chat" + embed_model = "fake-embed" + + async def embed(self, inputs: list[str]) -> list[list[float]]: + return [[float(len(s) % 5), 0.0, 0.0] for s in inputs] + + async def chat(self, *_: Any, **__: Any) -> str: + return "fake summary" + + +def _build_mixin( + *, + enable_fallback: bool = True, + configured_categories: list[CategoryConfig] | None = None, +) -> tuple[MemorizeMixin, Context, Database]: + cfg = MemorizeConfig( + memory_categories=configured_categories or [], + enable_uncategorized_fallback=enable_fallback, + ) + mixin = MemorizeMixin.__new__(MemorizeMixin) + mixin.memorize_config = cfg + mixin.category_configs = list(cfg.memory_categories) + mixin._get_llm_client = lambda profile=None, step_context=None: FakeLLMClient() # type: ignore[method-assign] + + db = build_database( + config=DatabaseConfig.model_validate({"metadata_store": {"provider": "inmemory"}}), + user_model=DefaultUserModel, + ) + return mixin, Context(), db + + +class TestConfigFlag: + def test_defaults_enabled(self): + """New deployments get the fallback automatically.""" + assert MemorizeConfig().enable_uncategorized_fallback is True + + def test_can_be_disabled(self): + cfg = MemorizeConfig(enable_uncategorized_fallback=False) + assert cfg.enable_uncategorized_fallback is False + + +class TestInitializeCategories: + @pytest.mark.asyncio + async def test_creates_uncategorized_when_no_user_categories(self): + """Even with empty memory_categories, the fallback is created.""" + mixin, ctx, db = _build_mixin() + await mixin._initialize_categories(ctx, db) + assert UNCATEGORIZED_CATEGORY_NAME in ctx.category_name_to_id + assert ctx.categories_ready is True + cat_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME] + assert cat_id in db.memory_category_repo.categories + + @pytest.mark.asyncio + async def test_appends_uncategorized_alongside_user_categories(self): + mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")]) + await mixin._initialize_categories(ctx, db) + assert "habits" in ctx.category_name_to_id + assert UNCATEGORIZED_CATEGORY_NAME in ctx.category_name_to_id + + @pytest.mark.asyncio + async def test_skipped_when_disabled(self): + mixin, ctx, db = _build_mixin(enable_fallback=False) + await mixin._initialize_categories(ctx, db) + assert UNCATEGORIZED_CATEGORY_NAME not in ctx.category_name_to_id + + @pytest.mark.asyncio + async def test_fallback_category_has_seeded_summary(self): + """route_category filters on cat.summary; only the fallback gets one at init.""" + mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="user habits")]) + await mixin._initialize_categories(ctx, db) + + fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME] + habits_id = ctx.category_name_to_id["habits"] + assert db.memory_category_repo.categories[fallback_id].summary + assert not db.memory_category_repo.categories[habits_id].summary + + @pytest.mark.asyncio + async def test_not_duplicated_if_user_already_configured_it(self): + """User-defined 'uncategorized' takes precedence; no second copy is added.""" + mixin, ctx, db = _build_mixin( + configured_categories=[CategoryConfig(name=UNCATEGORIZED_CATEGORY_NAME, description="user-defined")] + ) + await mixin._initialize_categories(ctx, db) + assert ( + sum(1 for c in db.memory_category_repo.categories.values() if c.name.lower() == UNCATEGORIZED_CATEGORY_NAME) + == 1 + ) + + +class TestPersistMemoryItems: + @pytest.mark.asyncio + async def test_links_uncategorized_item_to_fallback_category(self): + """An item with no matching categories ends up linked to the fallback.""" + mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")]) + await mixin._initialize_categories(ctx, db) + + items, rels, updates = await mixin._persist_memory_items( + resource_id="res-1", + structured_entries=[("profile", "User uses dark mode.", [])], + ctx=ctx, + store=db, + embed_client=FakeLLMClient(), + user={"user_id": "u1"}, + ) + + assert len(items) == 1 + assert len(rels) == 1 + fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME] + assert rels[0].category_id == fallback_id + assert fallback_id in updates + + @pytest.mark.asyncio + async def test_keeps_explicit_category_when_provided(self): + """If the LLM returns a matching category name, no fallback is applied.""" + mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")]) + await mixin._initialize_categories(ctx, db) + + _, rels, _ = await mixin._persist_memory_items( + resource_id="res-1", + structured_entries=[("profile", "User journals daily.", ["habits"])], + ctx=ctx, + store=db, + embed_client=FakeLLMClient(), + user={"user_id": "u1"}, + ) + + habits_id = ctx.category_name_to_id["habits"] + assert [r.category_id for r in rels] == [habits_id] + + @pytest.mark.asyncio + async def test_no_fallback_when_disabled(self): + """With the flag off, uncategorized items get no category links at all.""" + mixin, ctx, db = _build_mixin( + enable_fallback=False, + configured_categories=[CategoryConfig(name="habits", description="d")], + ) + await mixin._initialize_categories(ctx, db) + + _, rels, updates = await mixin._persist_memory_items( + resource_id="res-1", + structured_entries=[("profile", "User likes oolong tea.", [])], + ctx=ctx, + store=db, + embed_client=FakeLLMClient(), + user={"user_id": "u1"}, + ) + + assert rels == [] + assert updates == {} + + +class TestUpdateCategorySummaries: + @pytest.mark.asyncio + async def test_skips_summary_for_uncategorized_category(self): + """Heterogeneous uncategorized items should not trigger a summary LLM call.""" + mixin, ctx, db = _build_mixin() + await mixin._initialize_categories(ctx, db) + fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME] + + call_count = 0 + + class CountingClient(FakeLLMClient): + async def chat(self, *_: Any, **__: Any) -> str: + nonlocal call_count + call_count += 1 + return "should not happen" + + result = await mixin._update_category_summaries( + {fallback_id: [("item-1", "User uses dark mode.")]}, + ctx, + db, + llm_client=CountingClient(), + ) + + assert call_count == 0 + assert result == {} diff --git a/tests/test_xml_parser.py b/tests/test_xml_parser.py new file mode 100644 index 00000000..1b712b98 --- /dev/null +++ b/tests/test_xml_parser.py @@ -0,0 +1,116 @@ +""" +Tests for XML memory extraction parser in MemorizeMixin. + +Tests cover: +1. Memory items with empty `categories` are preserved (prompts allow empty). +2. Root-tag whitelist accepts every MemoryType value (singular and plural). +""" + +from __future__ import annotations + +import defusedxml.ElementTree as ET + +from memu.app.memorize import MemorizeMixin + + +def _mixin() -> MemorizeMixin: + return MemorizeMixin.__new__(MemorizeMixin) + + +class TestParseMemoryElement: + """Tests for _parse_memory_element.""" + + def test_keeps_item_with_empty_categories(self): + """Empty tag should not cause the item to be dropped.""" + elem = ET.fromstring("The user prefers dark mode.") + assert _mixin()._parse_memory_element(elem) == { + "content": "The user prefers dark mode.", + "categories": [], + } + + def test_keeps_item_without_categories_tag(self): + """Missing tag should not cause the item to be dropped.""" + elem = ET.fromstring("The user lives in Beijing.") + assert _mixin()._parse_memory_element(elem) == { + "content": "The user lives in Beijing.", + "categories": [], + } + + def test_drops_item_without_content(self): + """Items without are still dropped.""" + elem = ET.fromstring("x") + assert _mixin()._parse_memory_element(elem) is None + + def test_keeps_non_empty_categories(self): + """Non-empty is preserved verbatim.""" + elem = ET.fromstring( + "" + "The user drinks coffee daily." + "habitspreferences" + "" + ) + parsed = _mixin()._parse_memory_element(elem) + assert parsed is not None + assert parsed["categories"] == ["habits", "preferences"] + + +class TestFindXmlBoundaries: + """Tests for _find_xml_boundaries root-tag detection.""" + + def test_accepts_singular_event(self): + raw = "x" + boundaries = _mixin()._find_xml_boundaries(raw) + assert boundaries is not None + assert boundaries[2] == "" + + def test_accepts_singular_behavior(self): + raw = "x" + assert _mixin()._find_xml_boundaries(raw) is not None + + def test_accepts_singular_skill(self): + raw = "x" + assert _mixin()._find_xml_boundaries(raw) is not None + + def test_accepts_tool(self): + """`tool` is a valid MemoryType (database/models.py) but was missing.""" + raw = "x" + assert _mixin()._find_xml_boundaries(raw) is not None + + def test_accepts_legacy_plural_tags(self): + """Original whitelist must keep working.""" + for tag in ("profile", "behaviors", "events", "knowledge", "skills"): + raw = f"<{tag}>x" + assert _mixin()._find_xml_boundaries(raw) is not None, tag + + def test_rejects_unknown_root(self): + raw = "x" + assert _mixin()._find_xml_boundaries(raw) is None + + +class TestParseMemoryTypeResponseXml: + """End-to-end tests via _parse_memory_type_response_xml.""" + + def test_singular_root_with_empty_categories(self): + """LLM returns root with one item lacking categories.""" + raw = ( + "" + "The user attended a meetup in Beijing yesterday." + "" + "" + ) + items = _mixin()._parse_memory_type_response_xml(raw) + assert len(items) == 1 + assert items[0]["categories"] == [] + + def test_tool_root_with_mixed_categories(self): + """Two memories under , one categorised, one not.""" + raw = ( + "" + "Calculator added 2 and 2." + "math" + "Weather queried for Beijing." + "" + "" + ) + items = _mixin()._parse_memory_type_response_xml(raw) + assert [it["categories"] for it in items] == [["math"], []]