diff --git a/examples/basic_modules/nebular_example.py b/examples/basic_modules/nebular_example.py index 2f591330d..13f88e3f3 100644 --- a/examples/basic_modules/nebular_example.py +++ b/examples/basic_modules/nebular_example.py @@ -52,56 +52,6 @@ def embed_memory_item(memory: str) -> list[float]: return embedding_list -def example_multi_db(db_name: str = "paper"): - # Step 1: Build factory config - config = GraphDBConfigFactory( - backend="nebular", - config={ - "uri": json.loads(os.getenv("NEBULAR_HOSTS", "localhost")), - "user": os.getenv("NEBULAR_USER", "root"), - "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"), - "space": db_name, - "use_multi_db": True, - "auto_create": True, - "embedding_dimension": embedder_dimension, - }, - ) - - # Step 2: Instantiate the graph store - graph = GraphStoreFactory.from_config(config) - graph.clear() - - # Step 3: Create topic node - topic = TextualMemoryItem( - memory="This research addresses long-term multi-UAV navigation for energy-efficient communication coverage.", - metadata=TreeNodeTextualMemoryMetadata( - memory_type="LongTermMemory", - key="Multi-UAV Long-Term Coverage", - hierarchy_level="topic", - type="fact", - memory_time="2024-01-01", - source="file", - sources=["paper://multi-uav-coverage/intro"], - status="activated", - confidence=95.0, - tags=["UAV", "coverage", "multi-agent"], - entities=["UAV", "coverage", "navigation"], - visibility="public", - updated_at=datetime.now().isoformat(), - embedding=embed_memory_item( - "This research addresses long-term " - "multi-UAV navigation for " - "energy-efficient communication " - "coverage." - ), - ), - ) - - graph.add_node( - id=topic.id, memory=topic.memory, metadata=topic.metadata.model_dump(exclude_none=True) - ) - - def example_shared_db(db_name: str = "shared-traval-group"): """ Example: Single(Shared)-DB multi-tenant (logical isolation) @@ -404,9 +354,6 @@ def example_complex_shared_db(db_name: str = "shared-traval-group-complex"): if __name__ == "__main__": - print("\n=== Example: Multi-DB ===") - example_multi_db(db_name="paper-new") - print("\n=== Example: Single-DB ===") example_shared_db(db_name="shared_traval_group-new") diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 0eee05e74..38f08ff8d 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -169,7 +169,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -364,7 +364,7 @@ def __init__(self, config: NebulaGraphDBConfig): if (str(self.embedding_dimension) != str(self.default_memory_dimension)) else "embedding" ) - self.system_db_name = "system" if config.use_multi_db else config.space + self.system_db_name = config.space # ---- NEW: pool acquisition strategy # Get or create a shared pool from the class-level cache @@ -439,15 +439,13 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. """ - optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + optional_condition = f"AND n.user_name = '{self.config.user_name}'" try: count = self.count_nodes(memory_type) if count > keep_latest: delete_query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.memory_type = '{memory_type}' {optional_condition} ORDER BY n.updated_at DESC @@ -463,8 +461,7 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: """ Insert or update a Memory node in NebulaGraph. """ - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + metadata["user_name"] = self.config.user_name now = datetime.utcnow() metadata = metadata.copy() @@ -495,12 +492,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @timed def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {filter_clause} RETURN n.id AS id LIMIT 1 @@ -529,8 +523,7 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: MATCH (n@Memory {{id: "{id}"}}) """ - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{self.config.user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @@ -545,9 +538,8 @@ def delete_node(self, id: str) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" + user_name = self.config.user_name + query += f" WHERE n.user_name = {self._format_value(user_name)}" query += "\n DETACH DELETE n" self.execute_query(query) @@ -563,9 +555,7 @@ def add_edge(self, source_id: str, target_id: str, type: str): if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{self.config.user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) @@ -590,9 +580,8 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" + user_name = self.config.user_name + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @@ -603,9 +592,8 @@ def get_memory_count(self, memory_type: str) -> int: MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + user_name = self.config.user_name + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -622,9 +610,8 @@ def count_nodes(self, scope: str | None = None) -> int: if scope: conditions.append(f'n.memory_type = "{scope}"') - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - conditions.append(f"n.user_name = '{user_name}'") + user_name = self.config.user_name + conditions.append(f"n.user_name = '{user_name}'") if conditions: query += "\nWHERE " + " AND ".join(conditions) @@ -664,9 +651,8 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." ) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + user_name = self.config.user_name + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -689,10 +675,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | Returns: dict: Node properties as key-value pairs, or None if not found. """ - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' + filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" @@ -733,12 +716,10 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" + if kwargs.get("cube_name"): + where_user = f" AND n.user_name = '{kwargs['cube_name']}'" + else: + where_user = f" AND n.user_name = '{self.config.user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) @@ -794,8 +775,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" query = f""" MATCH {pattern} @@ -848,8 +828,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{self.config.user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -858,7 +837,7 @@ def get_neighbors_by_tag( query = f""" LET tag_list = {tag_list_literal} - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_clause} RETURN {return_fields}, size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count @@ -884,11 +863,8 @@ def get_neighbors_by_tag( @timed def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + user_name = self.config.user_name + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -1014,19 +990,18 @@ def search_by_embedding( where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + if kwargs.get("cube_name"): + where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') + else: + where_clauses.append(f'n.user_name = "{self.config.user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" @@ -1107,11 +1082,10 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{self.config.user_name}"') where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" + gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" ids = [] try: result = self.execute_query(gql) @@ -1143,16 +1117,15 @@ def get_grouped_counts( raise ValueError("group_fields cannot be empty") # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + user_clause = f"n.user_name = '{self.config.user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1195,10 +1168,9 @@ def clear(self) -> None: Clear the entire graph if the target database exists. """ try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" + query = ( + f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" + ) self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1222,10 +1194,9 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + username = self.config.user_name + node_query += f' WHERE n.user_name = "{username}"' + edge_query += f' WHERE r.user_name = "{username}"' try: if include_embedding: @@ -1296,8 +1267,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: id, memory, metadata = _compose_node(node) - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + metadata["user_name"] = self.config.user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) @@ -1313,9 +1283,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{self.config.user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1340,9 +1308,7 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{self.config.user_name}'" return_fields = self._build_return_fields(include_embedding) @@ -1376,8 +1342,7 @@ def get_structure_optimization_candidates( n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{self.config.user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" @@ -1412,14 +1377,10 @@ def drop_database(self) -> None: Permanently delete the entire database this instance is using. WARNING: This operation is destructive and cannot be undone. """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) + raise ValueError( + f"Refusing to drop protected database: `{self.db_name}` in " + f"Shared Database Multi-Tenant mode" + ) @timed def detect_conflicts(self) -> list[tuple[str, str]]: @@ -1624,9 +1585,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. """ - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}"