Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions src/memu/app/memorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

logger = logging.getLogger(__name__)

UNCATEGORIZED_CATEGORY_NAME = "uncategorized"

if TYPE_CHECKING:
from memu.app.service import Context
from memu.app.settings import MemorizeConfig
Expand Down Expand Up @@ -615,6 +617,10 @@ async def _persist_memory_items(
# existing item
continue
mapped_cat_ids = self._map_category_names_to_ids(cat_names, ctx)
if not mapped_cat_ids and self.memorize_config.enable_uncategorized_fallback:
fallback_id = ctx.category_name_to_id.get(UNCATEGORIZED_CATEGORY_NAME)
if fallback_id is not None:
mapped_cat_ids = [fallback_id]
for cid in mapped_cat_ids:
rels.append(store.category_item_repo.link_item_category(item.id, cid, user_data=dict(user or {})))
# Store (item_id, summary) tuple for reference support
Expand Down Expand Up @@ -650,19 +656,36 @@ async def _initialize_categories(
) -> None:
if ctx.categories_ready:
return
if not self.category_configs:
configs = list(self.category_configs)
if self.memorize_config.enable_uncategorized_fallback and not any(
cfg.name.lower() == UNCATEGORIZED_CATEGORY_NAME for cfg in configs
):
configs.append(
CategoryConfig(
name=UNCATEGORIZED_CATEGORY_NAME,
description="Memory items that do not match any other configured category.",
)
)
if not configs:
ctx.categories_ready = True
return
cat_texts = [self._category_embedding_text(cfg) for cfg in self.category_configs]
cat_texts = [self._category_embedding_text(cfg) for cfg in configs]
cat_vecs = await self._get_llm_client("embedding").embed(cat_texts)
ctx.category_ids = []
ctx.category_name_to_id = {}
for cfg, vec in zip(self.category_configs, cat_vecs, strict=True):
for cfg, vec in zip(configs, cat_vecs, strict=True):
name = cfg.name.strip() or "Untitled"
description = cfg.description.strip()
cat = store.memory_category_repo.get_or_create_category(
name=name, description=description, embedding=vec, user_data=dict(user or {})
)
# Seed a static summary for the fallback: _rank_categories_by_summary
# in retrieve.py filters on `cat.summary`, and we deliberately skip
# dynamic summary updates for this category (see _update_category_summaries).
# Must go through update_category so the value persists on SQL backends
# whose get_or_create_category returns a detached/copied instance.
if name.lower() == UNCATEGORIZED_CATEGORY_NAME and not cat.summary and description:
cat = store.memory_category_repo.update_category(category_id=cat.id, summary=description)
ctx.category_ids.append(cat.id)
ctx.category_name_to_id[name.lower()] = cat.id
ctx.categories_ready = True
Expand Down Expand Up @@ -1120,6 +1143,11 @@ async def _update_category_summaries(
cat = store.memory_category_repo.categories.get(cid)
if not cat or not memories:
continue
if cat.name.lower() == UNCATEGORIZED_CATEGORY_NAME:
# Summaries over heterogeneous uncategorized items are incoherent and
# waste tokens; the category's static description embedding is enough
# for route_category to consider it when nothing else matches.
continue
prompt = self._build_category_summary_prompt(category=cat, new_memories=memories)
tasks.append(client.chat(prompt))
target_ids.append(cid)
Expand Down Expand Up @@ -1259,7 +1287,20 @@ def _parse_memory_type_response(self, raw: str) -> list[dict[str, Any]]:

def _find_xml_boundaries(self, raw: str) -> tuple[int, int, str] | None:
"""Find the start index, end index, and closing tag for XML root element."""
root_tags = ["item", "profile", "behaviors", "events", "knowledge", "skills"]
root_tags = [
"item",
"profile",
"profiles",
"event",
"events",
"knowledge",
"behavior",
"behaviors",
"skill",
"skills",
"tool",
"tools",
]
for tag in root_tags:
opening = f"<{tag}>"
closing = f"</{tag}>"
Expand All @@ -1283,7 +1324,9 @@ def _parse_memory_element(self, memory_elem: Element) -> dict[str, Any] | None:
categories = [cat_elem.text.strip() for cat_elem in categories_elem.findall("category") if cat_elem.text]
memory_dict["categories"] = categories

if memory_dict.get("content") and memory_dict.get("categories"):
# `categories` may be empty per the memory type prompts.
if memory_dict.get("content"):
memory_dict.setdefault("categories", [])
return memory_dict
return None

Expand All @@ -1292,7 +1335,7 @@ def _parse_memory_type_response_xml(self, raw: str) -> list[dict[str, Any]]:
Parse XML memory extraction output into a list of memory items.

Expected XML format (root tag varies by memory type):
<profile|behaviors|events|knowledge|skills>
<item|profile|event[s]|knowledge|behavior[s]|skill[s]|tool[s]>
<memory>
<content>...</content>
<categories>
Expand Down
8 changes: 8 additions & 0 deletions src/memu/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ class MemorizeConfig(BaseModel):
default=False,
description="Enable reinforcement tracking for memory items.",
)
enable_uncategorized_fallback: bool = Field(
default=True,
description=(
"Auto-create an 'uncategorized' category and link memory items whose "
"extracted categories match none of the configured ones, so LLM-mode "
"retrieval (which joins items via category relations) can still reach them."
),
)


class PatchConfig(BaseModel):
Expand Down
200 changes: 200 additions & 0 deletions tests/test_uncategorized_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
"""
Tests for the uncategorized-fallback feature in MemorizeMixin.

When the LLM extracts a memory item that maps to no configured category,
it is linked to an auto-created "uncategorized" category so it remains
reachable from LLM-mode retrieval (which joins items via category relations).
"""

from __future__ import annotations

from typing import Any

import pytest

from memu.app.memorize import UNCATEGORIZED_CATEGORY_NAME, MemorizeMixin
from memu.app.service import Context
from memu.app.settings import CategoryConfig, DatabaseConfig, DefaultUserModel, MemorizeConfig
from memu.database.factory import build_database
from memu.database.interfaces import Database


class FakeLLMClient:
    """Deterministic stand-in for the embedding and chat client.

    Both coroutines compute their result locally from their arguments (or
    ignore them entirely), so repeated calls with the same input always
    produce the same output.
    """

    # Fixed placeholder identifiers exposed as class attributes.
    chat_model = "fake-chat"
    embed_model = "fake-embed"

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Return one 3-element vector per input string.

        The first component is ``len(s) % 5`` as a float and the remaining two
        are always ``0.0``; an empty ``inputs`` list yields an empty list.
        """
        return [[float(len(s) % 5), 0.0, 0.0] for s in inputs]

    async def chat(self, *_: Any, **__: Any) -> str:
        """Ignore every positional and keyword argument and return a fixed string."""
        return "fake summary"


def _build_mixin(
    *,
    enable_fallback: bool = True,
    configured_categories: list[CategoryConfig] | None = None,
) -> tuple[MemorizeMixin, Context, Database]:
    """Build a bare ``MemorizeMixin`` wired to fakes, plus a fresh context and store.

    Args:
        enable_fallback: Passed through as
            ``MemorizeConfig.enable_uncategorized_fallback``.
        configured_categories: Passed through as ``memory_categories``;
            ``None`` (or an empty list) is passed as ``[]``.

    Returns:
        A ``(mixin, ctx, db)`` tuple: the mixin instance, a new ``Context()``,
        and a database built with the ``inmemory`` metadata-store provider.
    """
    cfg = MemorizeConfig(
        memory_categories=configured_categories or [],
        enable_uncategorized_fallback=enable_fallback,
    )
    # Bypass __init__: only the attributes assigned below are set on the instance.
    mixin = MemorizeMixin.__new__(MemorizeMixin)
    mixin.memorize_config = cfg
    mixin.category_configs = list(cfg.memory_categories)
    # Every client lookup returns a fresh fake, regardless of the arguments passed.
    mixin._get_llm_client = lambda profile=None, step_context=None: FakeLLMClient()  # type: ignore[method-assign]

    db = build_database(
        config=DatabaseConfig.model_validate({"metadata_store": {"provider": "inmemory"}}),
        user_model=DefaultUserModel,
    )
    return mixin, Context(), db


class TestConfigFlag:
    """Checks on the ``enable_uncategorized_fallback`` field of ``MemorizeConfig``."""

    def test_defaults_enabled(self):
        """New deployments get the fallback automatically."""
        assert MemorizeConfig().enable_uncategorized_fallback is True

    def test_can_be_disabled(self):
        """Passing ``False`` explicitly is preserved on the config object."""
        cfg = MemorizeConfig(enable_uncategorized_fallback=False)
        assert cfg.enable_uncategorized_fallback is False


class TestInitializeCategories:
    """Checks on ``_initialize_categories`` with the fallback flag on and off."""

    @pytest.mark.asyncio
    async def test_creates_uncategorized_when_no_user_categories(self):
        """Even with empty memory_categories, the fallback is created."""
        mixin, ctx, db = _build_mixin()
        await mixin._initialize_categories(ctx, db)
        assert UNCATEGORIZED_CATEGORY_NAME in ctx.category_name_to_id
        assert ctx.categories_ready is True
        cat_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME]
        assert cat_id in db.memory_category_repo.categories

    @pytest.mark.asyncio
    async def test_appends_uncategorized_alongside_user_categories(self):
        """A configured category and the fallback are both registered by name."""
        mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")])
        await mixin._initialize_categories(ctx, db)
        assert "habits" in ctx.category_name_to_id
        assert UNCATEGORIZED_CATEGORY_NAME in ctx.category_name_to_id

    @pytest.mark.asyncio
    async def test_skipped_when_disabled(self):
        """With the flag off and no configured categories, nothing is registered."""
        mixin, ctx, db = _build_mixin(enable_fallback=False)
        await mixin._initialize_categories(ctx, db)
        assert UNCATEGORIZED_CATEGORY_NAME not in ctx.category_name_to_id
        # The empty-config path must still mark initialization as done.
        assert ctx.categories_ready is True

    @pytest.mark.asyncio
    async def test_fallback_category_has_seeded_summary(self):
        """route_category filters on cat.summary; only the fallback gets one at init."""
        mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="user habits")])
        await mixin._initialize_categories(ctx, db)

        fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME]
        habits_id = ctx.category_name_to_id["habits"]
        assert db.memory_category_repo.categories[fallback_id].summary
        assert not db.memory_category_repo.categories[habits_id].summary

    @pytest.mark.asyncio
    async def test_not_duplicated_if_user_already_configured_it(self):
        """User-defined 'uncategorized' takes precedence; no second copy is added."""
        mixin, ctx, db = _build_mixin(
            configured_categories=[CategoryConfig(name=UNCATEGORIZED_CATEGORY_NAME, description="user-defined")]
        )
        await mixin._initialize_categories(ctx, db)
        # The user-configured entry is still registered under the fallback name.
        assert UNCATEGORIZED_CATEGORY_NAME in ctx.category_name_to_id
        assert (
            sum(1 for c in db.memory_category_repo.categories.values() if c.name.lower() == UNCATEGORIZED_CATEGORY_NAME)
            == 1
        )


class TestPersistMemoryItems:
    """Checks on how ``_persist_memory_items`` links items to categories."""

    @pytest.mark.asyncio
    async def test_links_uncategorized_item_to_fallback_category(self):
        """An item with no matching categories ends up linked to the fallback."""
        mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")])
        await mixin._initialize_categories(ctx, db)

        # The entry's category list is empty, so only the fallback can apply.
        items, rels, updates = await mixin._persist_memory_items(
            resource_id="res-1",
            structured_entries=[("profile", "User uses dark mode.", [])],
            ctx=ctx,
            store=db,
            embed_client=FakeLLMClient(),
            user={"user_id": "u1"},
        )

        assert len(items) == 1
        assert len(rels) == 1
        fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME]
        assert rels[0].category_id == fallback_id
        assert fallback_id in updates

    @pytest.mark.asyncio
    async def test_keeps_explicit_category_when_provided(self):
        """If the LLM returns a matching category name, no fallback is applied."""
        mixin, ctx, db = _build_mixin(configured_categories=[CategoryConfig(name="habits", description="d")])
        await mixin._initialize_categories(ctx, db)

        _, rels, _ = await mixin._persist_memory_items(
            resource_id="res-1",
            structured_entries=[("profile", "User journals daily.", ["habits"])],
            ctx=ctx,
            store=db,
            embed_client=FakeLLMClient(),
            user={"user_id": "u1"},
        )

        habits_id = ctx.category_name_to_id["habits"]
        assert [r.category_id for r in rels] == [habits_id]

    @pytest.mark.asyncio
    async def test_no_fallback_when_disabled(self):
        """With the flag off, uncategorized items get no category links at all."""
        mixin, ctx, db = _build_mixin(
            enable_fallback=False,
            configured_categories=[CategoryConfig(name="habits", description="d")],
        )
        await mixin._initialize_categories(ctx, db)

        _, rels, updates = await mixin._persist_memory_items(
            resource_id="res-1",
            structured_entries=[("profile", "User likes oolong tea.", [])],
            ctx=ctx,
            store=db,
            embed_client=FakeLLMClient(),
            user={"user_id": "u1"},
        )

        assert rels == []
        assert updates == {}


class TestUpdateCategorySummaries:
    """Checks on ``_update_category_summaries`` for the fallback category."""

    @pytest.mark.asyncio
    async def test_skips_summary_for_uncategorized_category(self):
        """Heterogeneous uncategorized items should not trigger a summary LLM call."""
        mixin, ctx, db = _build_mixin()
        await mixin._initialize_categories(ctx, db)
        fallback_id = ctx.category_name_to_id[UNCATEGORIZED_CATEGORY_NAME]

        call_count = 0

        class CountingClient(FakeLLMClient):
            """Fake client that records how many times ``chat`` is invoked."""

            async def chat(self, *_: Any, **__: Any) -> str:
                nonlocal call_count
                call_count += 1
                return "should not happen"

        result = await mixin._update_category_summaries(
            {fallback_id: [("item-1", "User uses dark mode.")]},
            ctx,
            db,
            llm_client=CountingClient(),
        )

        # No chat call was made and no summary was produced for the fallback.
        assert call_count == 0
        assert result == {}
Loading
Loading