From c3dcb24be77fe45affb931c14f506da93fbdb4b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 19:20:04 +0800 Subject: [PATCH 1/6] feat: neo4j example now read memory from env --- examples/basic_modules/neo4j_example.py | 83 ++++++++++++++----------- src/memos/graph_dbs/neo4j.py | 4 +- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/examples/basic_modules/neo4j_example.py b/examples/basic_modules/neo4j_example.py index e1c0df317..2a5e88749 100644 --- a/examples/basic_modules/neo4j_example.py +++ b/examples/basic_modules/neo4j_example.py @@ -2,6 +2,8 @@ from datetime import datetime +from dotenv import load_dotenv + from memos.configs.embedder import EmbedderConfigFactory from memos.configs.graph_db import GraphDBConfigFactory from memos.embedders.factory import EmbedderFactory @@ -9,14 +11,27 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +load_dotenv() + +NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687") +NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") +NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678") +NEO4J_DB_NAME = os.getenv("NEO4J_DB_NAME", "neo4j") +EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIMENSION", "3072")) + +QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost") +QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) + embedder_config = EmbedderConfigFactory.model_validate( { - "backend": "universal_api", + "backend": os.getenv("MOS_EMBEDDER_BACKEND", "universal_api"), "config": { - "provider": "openai", - "api_key": os.getenv("OPENAI_API_KEY", "sk-xxxxx"), - "model_name_or_path": "text-embedding-3-large", - "base_url": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), + "provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"), + "api_key": os.getenv("MOS_EMBEDDER_API_KEY", os.getenv("OPENAI_API_KEY", "")), + "model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"), + "base_url": os.getenv( + "MOS_EMBEDDER_API_BASE", os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") + ), }, } ) @@ -31,12 +46,12 @@ def get_neo4j_graph(db_name: str = "paper"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://xxxx:7687", - "user": "neo4j", - "password": "xxxx", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, "use_multi_db": True, }, ) @@ -49,12 +64,12 @@ def example_multi_db(db_name: str = "paper"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, "use_multi_db": True, }, ) @@ -288,14 +303,14 @@ def example_shared_db(db_name: str = "shared-traval-group"): config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": False, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) # Step 2: Instantiate graph store @@ -353,12 +368,12 @@ def example_shared_db(db_name: str = "shared-traval-group"): config_alice = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_list[0], - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) graph_alice = GraphStoreFactory.from_config(config_alice) @@ -382,24 +397,22 @@ def run_user_session( config = GraphDBConfigFactory( backend="neo4j-community", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": False, - "auto_create": False, # Neo4j Community does not allow auto DB creation - "embedding_dimension": 3072, + "auto_create": False, + "embedding_dimension": EMBEDDING_DIMENSION, "vec_config": { - # Pass nested config to initialize external vector DB - # If you use qdrant, please use Server instead of local mode. "backend": "qdrant", "config": { "collection_name": "neo4j_vec_db", - "vector_dimension": 3072, + "vector_dimension": EMBEDDING_DIMENSION, "distance_metric": "cosine", - "host": "localhost", - "port": 6333, + "host": QDRANT_HOST, + "port": QDRANT_PORT, }, }, }, @@ -408,14 +421,14 @@ def run_user_session( config = GraphDBConfigFactory( backend="neo4j", config={ - "uri": "bolt://localhost:7687", - "user": "neo4j", - "password": "12345678", + "uri": NEO4J_URI, + "user": NEO4J_USER, + "password": NEO4J_PASSWORD, "db_name": db_name, "user_name": user_name, "use_multi_db": False, "auto_create": True, - "embedding_dimension": 3072, + "embedding_dimension": EMBEDDING_DIMENSION, }, ) graph = GraphStoreFactory.from_config(config) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 33eb39692..0ff421f9b 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -39,7 +39,7 @@ def _prepare_node_metadata(metadata: dict[str, Any]) -> dict[str, Any]: metadata["embedding"] = [float(x) for x in embedding] # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) return metadata @@ -226,7 +226,7 @@ def add_node( """ # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) From 0761cb6923b9beae1b6e85958239794f8185cfe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 20:06:10 +0800 Subject: [PATCH 2/6] fix: neo4j pre-filter --- src/memos/graph_dbs/neo4j.py | 8 ++- src/memos/graph_dbs/neo4j_community.py | 70 ++++++++++++++------------ 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 0ff421f9b..43c9aa62b 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -901,14 +901,18 @@ def search_by_embedding( if extra_fields: return_clause = f"RETURN node.id AS id, score, {extra_fields}" + has_post_filter = bool(where_clause) + vector_k = max(top_k * 10, 200) if has_post_filter else top_k + query = f""" CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) YIELD node, score {where_clause} {return_clause} + LIMIT $limit """ - parameters = {"embedding": vector, "k": top_k} + parameters = {"embedding": vector, "k": vector_k, "limit": top_k} if scope: parameters["scope"] = scope @@ -1842,7 +1846,7 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) diff --git a/src/memos/graph_dbs/neo4j_community.py b/src/memos/graph_dbs/neo4j_community.py index 283e15115..56f64eae2 100644 --- a/src/memos/graph_dbs/neo4j_community.py +++ b/src/memos/graph_dbs/neo4j_community.py @@ -61,34 +61,35 @@ def add_node( metadata.setdefault("delete_record_id", "") # serialization - if metadata["sources"]: + if metadata.get("sources"): for idx in range(len(metadata["sources"])): metadata["sources"][idx] = json.dumps(metadata["sources"][idx]) # Extract required fields embedding = metadata.pop("embedding", None) - if embedding is None: - raise ValueError(f"Missing 'embedding' in metadata for node {id}") # Merge node and set metadata created_at = metadata.pop("created_at") updated_at = metadata.pop("updated_at") - vector_sync_status = "success" + vector_sync_status = "skipped" - try: - # Write to Vector DB - item = VecDBItem( - id=id, - vector=embedding, - payload={ - "memory": memory, - "vector_sync": vector_sync_status, - **metadata, # unpack all metadata keys to top-level - }, - ) - self.vec_db.add([item]) - except Exception as e: - logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}") - vector_sync_status = "failed" + if embedding is not None: + vector_sync_status = "success" + try: + item = VecDBItem( + id=id, + vector=embedding, + payload={ + "memory": memory, + "vector_sync": vector_sync_status, + **metadata, + }, + ) + self.vec_db.add([item]) + except Exception as e: + logger.warning(f"[VecDB] Vector insert failed for node {id}: {e}") + vector_sync_status = "failed" + else: + logger.warning(f"[add_node] No embedding for node {id}, skipping vector DB insert") metadata["vector_sync"] = vector_sync_status query = """ @@ -141,18 +142,21 @@ def add_nodes_batch(self, nodes: list[dict[str, Any]], user_name: str | None = N embedding = metadata.pop("embedding", None) - vector_sync_status = "success" - vec_items.append( - VecDBItem( - id=node_id, - vector=embedding, - payload={ - "memory": memory, - "vector_sync": vector_sync_status, - **metadata, - }, + if embedding is not None: + vector_sync_status = "success" + vec_items.append( + VecDBItem( + id=node_id, + vector=embedding, + payload={ + "memory": memory, + "vector_sync": vector_sync_status, + **metadata, + }, + ) ) - ) + else: + vector_sync_status = "skipped" created_at = metadata.pop("created_at") updated_at = metadata.pop("updated_at") @@ -1138,12 +1142,12 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]: node[time_field] = node[time_field].isoformat() node.pop("user_name", None) # serialization - if node["sources"]: + if node.get("sources"): for idx in range(len(node["sources"])): if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) @@ -1179,7 +1183,7 @@ def _parse_nodes(self, nodes_data: list[dict[str, Any]]) -> list[dict[str, Any]] if not ( isinstance(node["sources"][idx], str) and node["sources"][idx][0] == "{" - and node["sources"][idx][0] == "}" + and node["sources"][idx][-1] == "}" ): break node["sources"][idx] = json.loads(node["sources"][idx]) From 5a8cc86a35652c47544ce5a0d0edcecf66ee02f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 20:49:31 +0800 Subject: [PATCH 3/6] fix: neo4j pre-filter --- docker/docker-compose.yml | 2 +- src/memos/graph_dbs/neo4j.py | 51 +++++++++++++++++++++++------------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 0a8e2c634..6805ec781 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -26,7 +26,7 @@ services: - memos_network neo4j: - image: neo4j:5.26.4 + image: neo4j:5.26.6 container_name: neo4j-docker ports: - "7474:7474" # HTTP diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 43c9aa62b..57928e2fa 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -73,7 +73,10 @@ def _flatten_info_fields(metadata: dict[str, Any]) -> dict[str, Any]: class Neo4jGraphDB(BaseGraphDB): - """Neo4j-based implementation of a graph memory store.""" + """Neo4j-based implementation of a graph memory store. + + Requires Neo4j >= 5.18 for vector.similarity.cosine() pre-filtering support. + """ @require_python_package( import_name="neo4j", @@ -843,13 +846,14 @@ def search_by_embedding( If return_fields is specified, each dict also includes the requested fields. Notes: - - This method uses Neo4j native vector indexing to search for similar nodes. - - If scope is provided, it restricts results to nodes with matching memory_type. - - If 'status' is provided, only nodes with the matching status will be returned. + - When filters are present (scope, status, user_name, etc.), this method uses + Neo4j 5.18+ pre-filtering: MATCH + WHERE narrows candidates first, then + vector.similarity.cosine() computes similarity only on the filtered set. + This avoids the post-filter problem where queryNodes' global top-k excludes + the target user's nodes in a multi-tenant shared database. + - When no filters are present, the ANN vector index (db.index.vector.queryNodes) + is used for maximum efficiency. - If threshold is provided, only results with score >= threshold will be returned. - - If search_filter is provided, additional WHERE clauses will be added for metadata filtering. - - Typical use case: restrict to 'status = activated' to avoid - matching archived or merged nodes. """ user_name = user_name if user_name else self.config.user_name # Build WHERE clause dynamically @@ -901,18 +905,29 @@ def search_by_embedding( if extra_fields: return_clause = f"RETURN node.id AS id, score, {extra_fields}" - has_post_filter = bool(where_clause) - vector_k = max(top_k * 10, 200) if has_post_filter else top_k - - query = f""" - CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) - YIELD node, score - {where_clause} - {return_clause} - LIMIT $limit - """ + if where_clause: + # Pre-filtering (Neo4j 5.18+): filter nodes first, then compute similarity. + # This avoids the post-filter problem where relevant nodes are excluded + # from the global top-k returned by queryNodes. + where_clause += " AND node.embedding IS NOT NULL" + query = f""" + MATCH (node:Memory) + {where_clause} + WITH node, vector.similarity.cosine(node.embedding, $embedding) AS score + {return_clause} + ORDER BY score DESC + LIMIT $limit + """ + else: + # No filter: use ANN vector index for efficiency. + query = f""" + CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) + YIELD node, score + {return_clause} + LIMIT $limit + """ - parameters = {"embedding": vector, "k": vector_k, "limit": top_k} + parameters = {"embedding": vector, "k": top_k, "limit": top_k} if scope: parameters["scope"] = scope From 6f954ad2f30d5f071ec8dee0fef8ec088d9d9c18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 21:00:36 +0800 Subject: [PATCH 4/6] fix: neo4j pre-filter --- src/memos/graph_dbs/neo4j.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index 57928e2fa..930b19f6f 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -916,18 +916,17 @@ def search_by_embedding( WITH node, vector.similarity.cosine(node.embedding, $embedding) AS score {return_clause} ORDER BY score DESC - LIMIT $limit + LIMIT $top_k """ + parameters = {"embedding": vector, "top_k": top_k} else: # No filter: use ANN vector index for efficiency. query = f""" - CALL db.index.vector.queryNodes('memory_vector_index', $k, $embedding) + CALL db.index.vector.queryNodes('memory_vector_index', $top_k, $embedding) YIELD node, score {return_clause} - LIMIT $limit """ - - parameters = {"embedding": vector, "k": top_k, "limit": top_k} + parameters = {"embedding": vector, "top_k": top_k} if scope: parameters["scope"] = scope From fad379d53b4069fe758f332871903922017d15b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 21:03:36 +0800 Subject: [PATCH 5/6] test: add regression tests for Neo4j vector search pre-filter and sources KeyError Made-with: Cursor --- tests/graph_dbs/test_neo4j_vector_search.py | 425 ++++++++++++++++++++ 1 file changed, 425 insertions(+) create mode 100644 tests/graph_dbs/test_neo4j_vector_search.py diff --git a/tests/graph_dbs/test_neo4j_vector_search.py b/tests/graph_dbs/test_neo4j_vector_search.py new file mode 100644 index 000000000..3ed0b0587 --- /dev/null +++ b/tests/graph_dbs/test_neo4j_vector_search.py @@ -0,0 +1,425 @@ +""" +Tests for Neo4j vector search pre-filtering and related regressions. + +- Unit tests: verify query structure (pre-filter vs ANN paths) using mocks +- Integration tests: verify real search behavior with multi-user data (requires Neo4j 5.18+) + +The pre-filter approach (Neo4j 5.18+): + When WHERE filters are present (scope, status, user_name, etc.), the query uses + MATCH + WHERE to narrow candidates first, then vector.similarity.cosine() + computes similarity only on the filtered set. This avoids the post-filter + problem entirely — no nodes are lost due to global top-k truncation. + + When no filters are present, the ANN vector index (db.index.vector.queryNodes) + is used for maximum efficiency. +""" + +import math +import os +import uuid + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from memos.configs.graph_db import Neo4jGraphDBConfig + + +# ────────────────────────────────────────────────────────────────────────────── +# Fixtures for unit tests (mocked Neo4j driver) +# ────────────────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def shared_db_config(): + """Shared-database multi-tenant config (use_multi_db=False).""" + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_db", + auto_create=False, + use_multi_db=False, + user_name="default_user", + embedding_dimension=3, + ) + + +@pytest.fixture +def multi_db_config(): + """Multi-database config — no user_name filter in queries.""" + return Neo4jGraphDBConfig( + uri="bolt://localhost:7687", + user="neo4j", + password="test", + db_name="test_db", + auto_create=False, + use_multi_db=True, + embedding_dimension=3, + ) + + +@pytest.fixture +def shared_neo4j_db(shared_db_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(shared_db_config) + db.driver = mock_driver + yield db + + +@pytest.fixture +def multi_neo4j_db(multi_db_config): + with patch("neo4j.GraphDatabase") as mock_gd: + mock_driver = MagicMock() + mock_gd.driver.return_value = mock_driver + from memos.graph_dbs.neo4j import Neo4jGraphDB + + db = Neo4jGraphDB(multi_db_config) + db.driver = mock_driver + yield db + + +# ────────────────────────────────────────────────────────────────────────────── +# Unit tests: pre-filter vs ANN query paths +# ────────────────────────────────────────────────────────────────────────────── + + +class TestVectorSearchPreFilter: + """Verify pre-filter path uses MATCH + vector.similarity.cosine() + and ANN path uses db.index.vector.queryNodes.""" + + def test_prefilter_with_scope(self, shared_neo4j_db): + """With scope filter, query should use MATCH + cosine similarity, not queryNodes.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "MATCH (node:Memory)" in query + assert "vector.similarity.cosine(node.embedding, $embedding)" in query + assert "queryNodes" not in query + + def test_prefilter_with_all_filters(self, shared_neo4j_db): + """With scope + status + user_name, all filters appear in WHERE before similarity.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=10, + scope="LongTermMemory", + status="activated", + user_name="some_user", + ) + + query = session_mock.run.call_args[0][0] + assert "MATCH (node:Memory)" in query + assert "node.memory_type = $scope" in query + assert "node.status = $status" in query + assert "node.user_name = $user_name" in query + assert "vector.similarity.cosine" in query + + def test_prefilter_includes_embedding_not_null(self, shared_neo4j_db): + """Pre-filter query should exclude nodes without embeddings.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "node.embedding IS NOT NULL" in query + + def test_prefilter_has_order_by_and_limit(self, shared_neo4j_db): + """Pre-filter results should be ordered by score and limited.""" + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + shared_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + scope="LongTermMemory", + ) + + query = session_mock.run.call_args[0][0] + assert "ORDER BY score DESC" in query + assert "LIMIT $top_k" in query + params = session_mock.run.call_args[0][1] + assert params["top_k"] == 5 + + def test_ann_path_without_filters(self, multi_neo4j_db): + """Without any filter, query should use queryNodes ANN index.""" + session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + multi_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + ) + + query = session_mock.run.call_args[0][0] + assert "queryNodes" in query + assert "$top_k" in query + assert "MATCH (node:Memory)" not in query + params = session_mock.run.call_args[0][1] + assert params["top_k"] == 5 + + def test_ann_path_no_redundant_params(self, multi_neo4j_db): + """ANN path should only have embedding and top_k, nothing extra.""" + session_mock = multi_neo4j_db.driver.session.return_value.__enter__.return_value + session_mock.run.return_value = [] + + multi_neo4j_db.search_by_embedding( + vector=[0.1, 0.2, 0.3], + top_k=5, + ) + + params = session_mock.run.call_args[0][1] + assert set(params.keys()) == {"embedding", "top_k"} + + +# ────────────────────────────────────────────────────────────────────────────── +# Unit tests: sources KeyError regression +# ────────────────────────────────────────────────────────────────────────────── + + +class TestSourcesKeyErrorRegression: + """Verify that missing/None 'sources' key doesn't cause KeyError.""" + + def test_add_node_without_sources_key(self, shared_neo4j_db): + session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + + shared_neo4j_db.add_node( + id="test-node-1", + memory="test content", + metadata={ + "memory_type": "WorkingMemory", + "embedding": [0.1, 0.2, 0.3], + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + ) + + calls = session_mock.run.call_args_list + assert any("MERGE (n:Memory" in str(call) for call in calls) + + def test_add_node_with_empty_sources(self, shared_neo4j_db): + _session_mock = shared_neo4j_db.driver.session.return_value.__enter__.return_value + + shared_neo4j_db.add_node( + id="test-node-2", + memory="test content", + metadata={ + "memory_type": "WorkingMemory", + "embedding": [0.1, 0.2, 0.3], + "sources": [], + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + ) + + def test_parse_node_without_sources_key(self, shared_neo4j_db): + result = shared_neo4j_db._parse_node( + { + "id": "node-1", + "memory": "hello", + "memory_type": "WorkingMemory", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + assert result["id"] == "node-1" + assert result["memory"] == "hello" + + +# ────────────────────────────────────────────────────────────────────────────── +# Integration tests (require a running Neo4j 5.18+ with vector index) +# +# Activate by setting environment variables: +# NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD +# +# Run: +# pytest tests/graph_dbs/test_neo4j_vector_search.py -k Integration -v +# ────────────────────────────────────────────────────────────────────────────── + + +def _neo4j_package_available(): + try: + import neo4j # noqa: F401 + + return True + except ImportError: + return False + + +_neo4j_configured = _neo4j_package_available() and all( + os.getenv(k) for k in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD") +) +_TEST_RUN_ID = uuid.uuid4().hex[:8] +_TARGET_USER = f"__test_target_{_TEST_RUN_ID}" +_OTHER_USER_PREFIX = f"__test_other_{_TEST_RUN_ID}" + + +def _make_unit_vector( + dim: int, dominant_axis: int, secondary_axis: int | None = None +) -> list[float]: + """ + Create a unit vector concentrated on one axis, optionally blended with a second. + + Used to control cosine similarity in tests: + - Two vectors on the same axis → cos_sim ≈ 1.0 + - Orthogonal axes → cos_sim ≈ 0.0 + - Blended → cos_sim ≈ 0.707 + """ + vec = [0.0] * dim + vec[dominant_axis % dim] = 1.0 + if secondary_axis is not None: + vec[secondary_axis % dim] = 1.0 + norm = math.sqrt(sum(x * x for x in vec)) + return [x / norm for x in vec] + + +@pytest.fixture(scope="module") +def integration_config(): + if not _neo4j_configured: + pytest.skip("Neo4j not configured (need NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)") + return Neo4jGraphDBConfig( + uri=os.getenv("NEO4J_URI"), + user=os.getenv("NEO4J_USER"), + password=os.getenv("NEO4J_PASSWORD"), + db_name=os.getenv("NEO4J_DB_NAME", "neo4j"), + auto_create=False, + use_multi_db=False, + user_name=f"__test_default_{_TEST_RUN_ID}", + embedding_dimension=int(os.getenv("EMBEDDING_DIMENSION", "1536")), + ) + + +@pytest.fixture(scope="module") +def integration_db(integration_config): + from memos.graph_dbs.neo4j import Neo4jGraphDB + + return Neo4jGraphDB(integration_config) + + +@pytest.mark.skipif(not _neo4j_configured, reason="Neo4j not configured") +class TestNeo4jPreFilterIntegration: + """ + Integration test: pre-filtered vector search in a multi-user shared database. + + Uses vector.similarity.cosine() with MATCH + WHERE to pre-filter by user, + guaranteeing that target user's nodes are always considered regardless of + how many other users' nodes exist in the database. + """ + + @pytest.fixture(scope="class", autouse=True) + def seed_and_cleanup(self, integration_db, integration_config): + """ + Seed multi-user test data, then clean up. + + - 50 "other" user nodes: embeddings along axis 0 → cos_sim ≈ 1.0 with query + - 3 "target" user nodes: embeddings blended axis 0+1 → cos_sim ≈ 0.707 with query + + With pre-filtering, only the target user's 3 nodes are candidates for + similarity computation, so all 3 are always returned. + """ + dim = integration_config.embedding_dimension + now = datetime.now(timezone.utc).isoformat() + + for i in range(50): + other_user = f"{_OTHER_USER_PREFIX}_{i % 10}" + integration_db.add_node( + id=f"__test_other_{_TEST_RUN_ID}_{i}", + memory=f"Other user memory {i}", + metadata={ + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": _make_unit_vector(dim, dominant_axis=0), + "created_at": now, + "updated_at": now, + }, + user_name=other_user, + ) + + for i in range(3): + integration_db.add_node( + id=f"__test_target_{_TEST_RUN_ID}_{i}", + memory=f"Target user memory {i}", + metadata={ + "memory_type": "LongTermMemory", + "status": "activated", + "embedding": _make_unit_vector(dim, dominant_axis=0, secondary_axis=1), + "created_at": now, + "updated_at": now, + }, + user_name=_TARGET_USER, + ) + + yield + + integration_db.clear(user_name=_TARGET_USER) + for i in range(10): + integration_db.clear(user_name=f"{_OTHER_USER_PREFIX}_{i}") + + def test_search_returns_all_target_user_results(self, integration_db, integration_config): + """Pre-filtering guarantees all target user nodes are found.""" + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + assert len(results) == 3, ( + f"Pre-filter should return all 3 target user nodes, got {len(results)}. " + "This indicates pre-filtering is not working correctly." + ) + + def test_all_returned_ids_belong_to_target_user(self, integration_db, integration_config): + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + for r in results: + assert r["id"].startswith(f"__test_target_{_TEST_RUN_ID}_"), ( + f"Result {r['id']} does not belong to the target user" + ) + + def test_scores_are_positive(self, integration_db, integration_config): + dim = integration_config.embedding_dimension + query_vector = _make_unit_vector(dim, dominant_axis=0) + + results = integration_db.search_by_embedding( + vector=query_vector, + top_k=3, + scope="LongTermMemory", + status="activated", + user_name=_TARGET_USER, + ) + + for r in results: + assert r["score"] > 0, f"Score should be positive, got {r['score']}" From 9e7ec1bc869430b5e615a7639ede9135e0ee2db0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Thu, 26 Mar 2026 21:08:54 +0800 Subject: [PATCH 6/6] style: fix ruff formatting in server_api.py Made-with: Cursor --- src/memos/api/server_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 38be6d007..1f1e6ccde 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -39,6 +39,7 @@ def health_check(): "version": app.version, } + # Request validation failed app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) # Invalid business code parameters