Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.

Commit 9aaf5f8

Browse files
Invalidate profile cache after memory ingest
1 parent 5b04cc7 commit 9aaf5f8

4 files changed

Lines changed: 71 additions & 1 deletion

File tree

src/api/routes/memory.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ async def _run_ingest_payload(
193193
image_url=payload.get("image_url", ""),
194194
effort_level=payload.get("effort_level", "low"),
195195
)
196+
_invalidate_profile_cache(user_id)
196197
data = IngestResponse(
197198
model=_model_name(pipeline.model),
198199
classification=_safe_classifications(result),
@@ -764,6 +765,13 @@ def _safe_classifications(result: Dict[str, Any]) -> list:
764765
return []
765766

766767

768+
def _invalidate_profile_cache(user_id: str) -> None:
769+
try:
770+
get_retrieval_pipeline().invalidate_profile_cache(user_id)
771+
except Exception as exc:
772+
logger.warning("Failed to invalidate profile cache for user=%s: %s", user_id, exc)
773+
774+
767775
async def _read_user_job(job_id: str, user_id: str) -> Dict[str, Any] | None:
768776
job = await asyncio.to_thread(get_default_job_store().get, job_id)
769777
if not job:

src/pipelines/retrieval.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,14 @@ async def _get_profile_catalog(self, user_id: str):
697697
self._trim_cache(self._profile_catalog_cache, _PROFILE_CATALOG_CACHE_LIMIT)
698698
return catalog, results
699699

700+
def invalidate_profile_cache(self, user_id: str) -> None:
701+
"""Clear cached profile records after a user's memories are ingested."""
702+
703+
self._profile_catalog_cache.pop(user_id, None)
704+
for key in list(self._retrieval_plan_cache):
705+
if key[0] == user_id:
706+
self._retrieval_plan_cache.pop(key, None)
707+
700708
def _fetch_profile_catalog(self, user_id: str):
701709
"""Fetch all profile entries for a user.
702710

tests/api/test_memory_search_routes.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class FakeSearchPipeline:
1818

1919
def __init__(self) -> None:
2020
self.answer_calls = 0
21+
self.invalidated_users: list[str] = []
2122
self.search_calls: list[dict[str, object]] = []
2223
self.latencies: dict[str, list[float]] = {}
2324

@@ -69,14 +70,24 @@ def get_latency_snapshot(self):
6970
for mode, samples in self.latencies.items()
7071
}
7172

73+
def invalidate_profile_cache(self, user_id: str) -> None:
74+
self.invalidated_users.append(user_id)
75+
76+
77+
class FakeIngestPipeline:
78+
model = SimpleNamespace(model="fake-ingest")
79+
80+
async def run(self, **kwargs):
81+
return {"classification_result": SimpleNamespace(classifications=[])}
82+
7283

7384
@pytest.fixture
7485
def memory_search_app(monkeypatch):
7586
pipeline = FakeSearchPipeline()
7687
monkeypatch.setattr(deps.settings, "api_keys", ["test-static-key"], raising=False)
7788
deps._init_error = None
7889
deps._pipelines_ready.set()
79-
deps.set_pipelines(SimpleNamespace(), pipeline)
90+
deps.set_pipelines(FakeIngestPipeline(), pipeline)
8091

8192
app = FastAPI()
8293
app.add_middleware(RequestContextMiddleware)
@@ -152,3 +163,19 @@ def test_memory_search_route_accepts_code_domain(memory_search_app):
152163
assert payload["data"]["results"][0]["domain"] == "code"
153164
assert payload["data"]["results"][0]["metadata"]["target_file"] == "src/retry.py"
154165
assert pipeline.search_calls[0]["domains"] == ["code"]
166+
167+
168+
def test_memory_ingest_invalidates_retrieval_profile_cache(memory_search_app):
169+
app, pipeline = memory_search_app
170+
response = TestClient(app).post(
171+
"/v1/memory/ingest",
172+
headers={"Authorization": "Bearer test-static-key"},
173+
json={
174+
"user_query": "Remember that I work at XMem",
175+
"agent_response": "Acknowledged.",
176+
"user_id": "ignored-by-auth",
177+
},
178+
)
179+
180+
assert response.status_code == 200
181+
assert pipeline.invalidated_users == ["Static Key User"]

tests/integration/test_retrieval_pipeline.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,33 @@ def test_retrieval_plan_cache_evicts_oldest_entry(vector_store, neo4j_client):
332332
assert pipeline._get_cached_retrieval_plan(first_key) is None
333333

334334

335+
def test_invalidate_profile_cache_clears_user_profile_and_plan_entries(
336+
vector_store, neo4j_client
337+
):
338+
model = FakeChatModel()
339+
pipeline = RetrievalPipeline(
340+
model=model, vector_store=vector_store, neo4j_client=neo4j_client
341+
)
342+
343+
pipeline._profile_catalog_cache["alice"] = (999999999.0, [], [])
344+
pipeline._profile_catalog_cache["bob"] = (999999999.0, [], [])
345+
pipeline._cache_retrieval_plan(
346+
("alice", "where do I work?", 5, "catalog-a"),
347+
FakeLLMResponse("alice-plan"),
348+
)
349+
pipeline._cache_retrieval_plan(
350+
("bob", "where do I work?", 5, "catalog-b"),
351+
FakeLLMResponse("bob-plan"),
352+
)
353+
354+
pipeline.invalidate_profile_cache("alice")
355+
356+
assert "alice" not in pipeline._profile_catalog_cache
357+
assert "bob" in pipeline._profile_catalog_cache
358+
assert not any(key[0] == "alice" for key in pipeline._retrieval_plan_cache)
359+
assert any(key[0] == "bob" for key in pipeline._retrieval_plan_cache)
360+
361+
335362
@pytest.mark.asyncio
336363
async def test_answer_from_sources_skips_tool_selection(vector_store, neo4j_client):
337364
model = FakeChatModel(responses=["Alice works at XMem."])

0 commit comments

Comments
 (0)